This commit is contained in:
2026-01-02 04:02:02 +00:00
commit 3b8e171fdb
117 changed files with 21570 additions and 0 deletions
+437
View File
@@ -0,0 +1,437 @@
package analytics
import (
"sort"
"sync"
"time"
"github.com/rs/zerolog/log"
)
// PackageDownload represents a package download event
type PackageDownload struct {
Registry string
Name string
Version string
Timestamp time.Time
BytesSize int64
ClientIP string
UserAgent string
}
// PackageStats holds statistics for a package
type PackageStats struct {
Registry string
Name string
TotalDownloads int64
UniqueVersions int
LastDownload time.Time
FirstSeen time.Time
BytesServed int64
}
// TrendData represents trend information over time
type TrendData struct {
Period time.Duration
Downloads int64
Packages int
}
// PopularPackage represents a popular package entry
type PopularPackage struct {
Registry string
Name string
Downloads int64
RecentDownloads int64 // Downloads in last 7 days
Trend float64 // Growth rate
}
// Engine tracks and analyzes package downloads
type Engine struct {
downloads []PackageDownload
downloadsMu sync.RWMutex
stats map[string]*PackageStats // key: registry:name
statsMu sync.RWMutex
maxEvents int
flushTicker *time.Ticker
stopChan chan struct{}
}
// Config holds analytics engine configuration
type Config struct {
MaxEvents int
FlushInterval time.Duration
}
// NewEngine creates a new analytics engine
func NewEngine(cfg Config) *Engine {
if cfg.MaxEvents <= 0 {
cfg.MaxEvents = 10000
}
if cfg.FlushInterval <= 0 {
cfg.FlushInterval = 5 * time.Minute
}
engine := &Engine{
downloads: make([]PackageDownload, 0, cfg.MaxEvents),
stats: make(map[string]*PackageStats),
maxEvents: cfg.MaxEvents,
flushTicker: time.NewTicker(cfg.FlushInterval),
stopChan: make(chan struct{}),
}
// Load existing stats from metadata store
engine.loadStats()
// Start background flush goroutine
go engine.flushLoop()
log.Info().
Int("max_events", cfg.MaxEvents).
Dur("flush_interval", cfg.FlushInterval).
Msg("Analytics engine started")
return engine
}
// TrackDownload records a package download event
func (e *Engine) TrackDownload(download PackageDownload) {
e.downloadsMu.Lock()
defer e.downloadsMu.Unlock()
// Add to event buffer
e.downloads = append(e.downloads, download)
// Update in-memory stats
e.updateStats(download)
// Flush if buffer is full
if len(e.downloads) >= e.maxEvents {
go e.flush()
}
log.Debug().
Str("registry", download.Registry).
Str("package", download.Name).
Str("version", download.Version).
Msg("Download tracked")
}
// updateStats updates in-memory statistics
func (e *Engine) updateStats(download PackageDownload) {
e.statsMu.Lock()
defer e.statsMu.Unlock()
key := download.Registry + ":" + download.Name
stats, exists := e.stats[key]
if !exists {
stats = &PackageStats{
Registry: download.Registry,
Name: download.Name,
FirstSeen: download.Timestamp,
}
e.stats[key] = stats
}
stats.TotalDownloads++
stats.BytesServed += download.BytesSize
stats.LastDownload = download.Timestamp
// Track unique versions (simplified)
stats.UniqueVersions++
}
// GetPackageStats returns statistics for a specific package
func (e *Engine) GetPackageStats(registry, name string) (*PackageStats, bool) {
e.statsMu.RLock()
defer e.statsMu.RUnlock()
key := registry + ":" + name
stats, exists := e.stats[key]
if !exists {
return nil, false
}
// Return a copy to avoid race conditions
statsCopy := *stats
return &statsCopy, true
}
// GetTopPackages returns the most downloaded packages
func (e *Engine) GetTopPackages(limit int) []PopularPackage {
e.statsMu.RLock()
defer e.statsMu.RUnlock()
packages := make([]PopularPackage, 0, len(e.stats))
for _, stats := range e.stats {
packages = append(packages, PopularPackage{
Registry: stats.Registry,
Name: stats.Name,
Downloads: stats.TotalDownloads,
})
}
// Sort by downloads descending
sort.Slice(packages, func(i, j int) bool {
return packages[i].Downloads > packages[j].Downloads
})
if limit > 0 && limit < len(packages) {
packages = packages[:limit]
}
return packages
}
// GetTrendingPackages returns packages with growing popularity
func (e *Engine) GetTrendingPackages(limit int) []PopularPackage {
e.statsMu.RLock()
defer e.statsMu.RUnlock()
sevenDaysAgo := time.Now().Add(-7 * 24 * time.Hour)
packages := make([]PopularPackage, 0)
for _, stats := range e.stats {
// Calculate recent downloads (last 7 days)
recent := e.getRecentDownloads(stats.Registry, stats.Name, sevenDaysAgo)
// Calculate trend (simple growth rate)
trend := 0.0
if stats.TotalDownloads > 0 {
trend = float64(recent) / float64(stats.TotalDownloads) * 100
}
packages = append(packages, PopularPackage{
Registry: stats.Registry,
Name: stats.Name,
Downloads: stats.TotalDownloads,
RecentDownloads: recent,
Trend: trend,
})
}
// Sort by trend descending
sort.Slice(packages, func(i, j int) bool {
return packages[i].Trend > packages[j].Trend
})
if limit > 0 && limit < len(packages) {
packages = packages[:limit]
}
return packages
}
// getRecentDownloads counts downloads since a given time
func (e *Engine) getRecentDownloads(registry, name string, since time.Time) int64 {
e.downloadsMu.RLock()
defer e.downloadsMu.RUnlock()
count := int64(0)
for _, download := range e.downloads {
if download.Registry == registry &&
download.Name == name &&
download.Timestamp.After(since) {
count++
}
}
return count
}
// GetTrends returns download trends over different time periods
func (e *Engine) GetTrends() []TrendData {
e.downloadsMu.RLock()
defer e.downloadsMu.RUnlock()
now := time.Now()
periods := []time.Duration{
1 * time.Hour,
24 * time.Hour,
7 * 24 * time.Hour,
30 * 24 * time.Hour,
}
trends := make([]TrendData, len(periods))
for i, period := range periods {
since := now.Add(-period)
downloads := int64(0)
packages := make(map[string]bool)
for _, download := range e.downloads {
if download.Timestamp.After(since) {
downloads++
packages[download.Registry+":"+download.Name] = true
}
}
trends[i] = TrendData{
Period: period,
Downloads: downloads,
Packages: len(packages),
}
}
return trends
}
// GetTotalStats returns overall statistics
func (e *Engine) GetTotalStats() map[string]interface{} {
e.statsMu.RLock()
defer e.statsMu.RUnlock()
totalDownloads := int64(0)
totalBytes := int64(0)
registries := make(map[string]int64)
for _, stats := range e.stats {
totalDownloads += stats.TotalDownloads
totalBytes += stats.BytesServed
registries[stats.Registry]++
}
return map[string]interface{}{
"total_packages": len(e.stats),
"total_downloads": totalDownloads,
"total_bytes": totalBytes,
"registries": registries,
}
}
// flushLoop periodically flushes download events to metadata store
func (e *Engine) flushLoop() {
for {
select {
case <-e.flushTicker.C:
e.flush()
case <-e.stopChan:
e.flush() // Final flush
return
}
}
}
// flush persists download events to metadata store
func (e *Engine) flush() {
e.downloadsMu.Lock()
downloads := e.downloads
e.downloads = make([]PackageDownload, 0, e.maxEvents)
e.downloadsMu.Unlock()
if len(downloads) == 0 {
return
}
log.Debug().
Int("events", len(downloads)).
Msg("Flushing analytics events")
// In a real implementation, this would persist to the metadata store
// For now, we just clear the buffer
// TODO: Add actual persistence when metadata store supports analytics tables
}
// loadStats loads existing statistics from metadata store
func (e *Engine) loadStats() {
// TODO: Load stats from metadata store when analytics tables are implemented
log.Debug().Msg("Loading analytics stats from metadata store")
}
// Close stops the analytics engine
func (e *Engine) Close() {
close(e.stopChan)
e.flushTicker.Stop()
e.flush() // Final flush
log.Info().Msg("Analytics engine stopped")
}
// GetRegistryStats returns per-registry statistics
func (e *Engine) GetRegistryStats(registry string) map[string]interface{} {
e.statsMu.RLock()
defer e.statsMu.RUnlock()
totalPackages := 0
totalDownloads := int64(0)
totalBytes := int64(0)
for _, stats := range e.stats {
if stats.Registry == registry {
totalPackages++
totalDownloads += stats.TotalDownloads
totalBytes += stats.BytesServed
}
}
return map[string]interface{}{
"registry": registry,
"total_packages": totalPackages,
"total_downloads": totalDownloads,
"total_bytes": totalBytes,
}
}
// SearchPackages finds packages matching a query
func (e *Engine) SearchPackages(query string, limit int) []PackageStats {
e.statsMu.RLock()
defer e.statsMu.RUnlock()
results := make([]PackageStats, 0)
for _, stats := range e.stats {
// Simple substring search
if contains(stats.Name, query) {
results = append(results, *stats)
}
if len(results) >= limit {
break
}
}
// Sort by downloads
sort.Slice(results, func(i, j int) bool {
return results[i].TotalDownloads > results[j].TotalDownloads
})
return results
}
// contains performs a case-insensitive substring search
func contains(s, substr string) bool {
sLower := toLower(s)
substrLower := toLower(substr)
return len(sLower) >= len(substrLower) &&
findSubstring(sLower, substrLower)
}
func toLower(s string) string {
result := make([]byte, len(s))
for i := 0; i < len(s); i++ {
c := s[i]
if c >= 'A' && c <= 'Z' {
result[i] = c + 32
} else {
result[i] = c
}
}
return string(result)
}
func findSubstring(s, substr string) bool {
if len(substr) == 0 {
return true
}
if len(s) < len(substr) {
return false
}
for i := 0; i <= len(s)-len(substr); i++ {
match := true
for j := 0; j < len(substr); j++ {
if s[i+j] != substr[j] {
match = false
break
}
}
if match {
return true
}
}
return false
}
+393
View File
@@ -0,0 +1,393 @@
package app
import (
"context"
"fmt"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/lukaszraczylo/gohoarder/pkg/analytics"
"github.com/lukaszraczylo/gohoarder/pkg/auth"
"github.com/lukaszraczylo/gohoarder/pkg/cache"
"github.com/lukaszraczylo/gohoarder/pkg/cdn"
"github.com/lukaszraczylo/gohoarder/pkg/config"
"github.com/lukaszraczylo/gohoarder/pkg/health"
"github.com/lukaszraczylo/gohoarder/pkg/lock"
"github.com/lukaszraczylo/gohoarder/pkg/logger"
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
metafile "github.com/lukaszraczylo/gohoarder/pkg/metadata/file"
metasqlite "github.com/lukaszraczylo/gohoarder/pkg/metadata/sqlite"
"github.com/lukaszraczylo/gohoarder/pkg/metrics"
"github.com/lukaszraczylo/gohoarder/pkg/network"
"github.com/lukaszraczylo/gohoarder/pkg/prewarming"
"github.com/lukaszraczylo/gohoarder/pkg/proxy/goproxy"
"github.com/lukaszraczylo/gohoarder/pkg/proxy/npm"
"github.com/lukaszraczylo/gohoarder/pkg/proxy/pypi"
"github.com/lukaszraczylo/gohoarder/pkg/scanner"
"github.com/lukaszraczylo/gohoarder/pkg/storage"
"github.com/lukaszraczylo/gohoarder/pkg/storage/filesystem"
"github.com/lukaszraczylo/gohoarder/pkg/websocket"
"github.com/rs/zerolog/log"
)
// App represents the main application
type App struct {
config *config.Config
server *http.Server
healthChecker *health.Checker
cache *cache.Manager
storage storage.StorageBackend
metadata metadata.Store
authManager *auth.Manager
networkClient *network.Client
scanManager *scanner.Manager
rescanWorker *scanner.RescanWorker
analyticsEngine *analytics.Engine
wsServer *websocket.Server
prewarmWorker *prewarming.Worker
lockManager *lock.Manager
cdnMiddleware *cdn.Middleware
}
// New creates a new application instance
func New(cfg *config.Config) (*App, error) {
app := &App{
config: cfg,
}
// Initialize components
if err := app.initializeComponents(); err != nil {
return nil, err
}
// Setup HTTP server and routes
if err := app.setupServer(); err != nil {
return nil, err
}
return app, nil
}
// initializeComponents initializes all application components
func (a *App) initializeComponents() error {
var err error
// Initialize storage backend
log.Info().Str("backend", a.config.Storage.Backend).Msg("Initializing storage backend")
switch a.config.Storage.Backend {
case "filesystem":
a.storage, err = filesystem.New(a.config.Storage.Path, a.config.Cache.MaxSizeBytes)
default:
a.storage, err = filesystem.New(a.config.Storage.Path, a.config.Cache.MaxSizeBytes)
}
if err != nil {
return fmt.Errorf("failed to initialize storage: %w", err)
}
// Initialize metadata store
log.Info().Str("backend", a.config.Metadata.Backend).Msg("Initializing metadata store")
switch a.config.Metadata.Backend {
case "sqlite":
a.metadata, err = metasqlite.New(metasqlite.Config{
Path: a.config.Metadata.Connection,
})
case "file":
a.metadata, err = metafile.New(metafile.Config{
Path: a.config.Metadata.Connection,
})
default:
a.metadata, err = metasqlite.New(metasqlite.Config{
Path: "gohoarder.db",
})
}
if err != nil {
return fmt.Errorf("failed to initialize metadata: %w", err)
}
// Initialize scanner manager first (before cache)
log.Info().Msg("Initializing security scanner")
a.scanManager, err = scanner.New(a.config.Security, a.metadata)
if err != nil {
return fmt.Errorf("failed to initialize scanner: %w", err)
}
// Initialize cache manager with scanner
log.Info().Msg("Initializing cache manager")
a.cache, err = cache.New(a.storage, a.metadata, a.scanManager, cache.Config{
DefaultTTL: a.config.Cache.DefaultTTL,
CleanupInterval: 5 * time.Minute,
})
if err != nil {
return fmt.Errorf("failed to initialize cache: %w", err)
}
// Initialize network client
log.Info().Msg("Initializing network client")
a.networkClient = network.NewClient(network.Config{
Timeout: 5 * time.Minute,
MaxRetries: 3,
RetryDelay: 1 * time.Second,
RateLimit: 100,
RateBurst: 10,
CircuitBreaker: network.CircuitBreakerConfig{
Enabled: true,
FailureThreshold: 5,
SuccessThreshold: 2,
Timeout: 30 * time.Second,
},
UserAgent: "GoHoarder/1.0",
})
// Initialize authentication manager
log.Info().Msg("Initializing authentication manager")
a.authManager = auth.New()
// Initialize rescan worker if enabled
if a.config.Security.Enabled && a.config.Security.RescanInterval > 0 {
log.Info().Dur("interval", a.config.Security.RescanInterval).Msg("Initializing package rescan worker")
a.rescanWorker = scanner.NewRescanWorker(a.scanManager, a.metadata, a.config.Security.RescanInterval)
}
// Initialize analytics engine
log.Info().Msg("Initializing analytics engine")
a.analyticsEngine = analytics.NewEngine(analytics.Config{
MaxEvents: 10000,
FlushInterval: 5 * time.Minute,
})
// Initialize WebSocket server
log.Info().Msg("Initializing WebSocket server")
a.wsServer = websocket.NewServer(websocket.Config{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return true // Allow all origins in development
},
})
// Initialize pre-warming worker
log.Info().Msg("Initializing pre-warming worker")
a.prewarmWorker = prewarming.NewWorker(prewarming.Config{
Enabled: false, // Disabled by default
Interval: 1 * time.Hour,
MaxConcurrent: 5,
CacheManager: a.cache,
Analytics: a.analyticsEngine,
NetworkClient: a.networkClient,
})
// Initialize CDN middleware
log.Info().Msg("Initializing CDN middleware")
a.cdnMiddleware = cdn.NewMiddleware(cdn.Config{
DefaultCacheControl: cdn.CacheControl{
Public: true,
MaxAge: 3600,
SMaxAge: 7200,
},
EnableETag: true,
EnableVary: true,
})
// Initialize health checker
a.healthChecker = health.New()
a.healthChecker.AddCheck("storage", func(ctx context.Context) (health.Status, string) {
if err := a.storage.Health(ctx); err != nil {
return health.StatusUnhealthy, err.Error()
}
return health.StatusHealthy, ""
})
a.healthChecker.AddCheck("metadata", func(ctx context.Context) (health.Status, string) {
if err := a.metadata.Health(ctx); err != nil {
return health.StatusUnhealthy, err.Error()
}
return health.StatusHealthy, ""
})
a.healthChecker.AddCheck("cache", func(ctx context.Context) (health.Status, string) {
return health.StatusHealthy, "" // Cache is always healthy if initialized
})
a.healthChecker.AddCheck("scanner", func(ctx context.Context) (health.Status, string) {
if a.config.Security.Enabled {
if err := a.scanManager.Health(ctx); err != nil {
return health.StatusUnhealthy, err.Error()
}
}
return health.StatusHealthy, ""
})
log.Info().Msg("All components initialized successfully")
return nil
}
// setupServer sets up the HTTP server and routes
func (a *App) setupServer() error {
mux := http.NewServeMux()
// Health and metrics endpoints
mux.HandleFunc("/health", a.healthChecker.HealthHandler())
mux.HandleFunc("/health/ready", a.healthChecker.ReadyHandler())
mux.Handle("/metrics", metrics.Handler())
// WebSocket endpoint
mux.HandleFunc("/ws", a.wsServer.HandleWebSocket)
// API endpoints
mux.HandleFunc("/api/packages/", a.handlePackages) // Handles packages and vulnerabilities
mux.HandleFunc("/api/stats", a.handleStats)
mux.HandleFunc("/api/info", a.handleInfo)
// Admin endpoints (bypass management)
mux.HandleFunc("/api/admin/bypasses/", a.handleBypassByID) // Must come before /api/admin/bypasses
mux.HandleFunc("/api/admin/bypasses", a.handleAdminBypasses)
// Proxy handlers
goProxyHandler := goproxy.New(a.cache, a.networkClient, goproxy.Config{
Upstream: "https://proxy.golang.org",
SumDBURL: "https://sum.golang.org",
})
mux.Handle("/go/", http.StripPrefix("/go", goProxyHandler))
npmProxyHandler := npm.New(a.cache, a.networkClient, npm.Config{
Upstream: "https://registry.npmjs.org",
})
mux.Handle("/npm/", http.StripPrefix("/npm", npmProxyHandler))
pypiProxyHandler := pypi.New(a.cache, a.networkClient, pypi.Config{
Upstream: "https://pypi.org/simple",
})
mux.Handle("/pypi/", http.StripPrefix("/pypi", pypiProxyHandler))
// Serve frontend static files
frontendDir := "frontend/dist"
if _, err := os.Stat(frontendDir); err == nil {
log.Info().Str("dir", frontendDir).Msg("Serving frontend static files")
fs := http.FileServer(http.Dir(frontendDir))
mux.Handle("/", fs)
} else {
log.Warn().Msg("Frontend dist directory not found, frontend won't be served")
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
fmt.Fprintf(w, `
<html>
<head><title>GoHoarder</title></head>
<body>
<h1>GoHoarder Package Cache Proxy</h1>
<p>Frontend not built. Build with: <code>cd frontend && npm run build</code></p>
<h2>Available Endpoints:</h2>
<ul>
<li><a href="/health">Health Check</a></li>
<li><a href="/metrics">Metrics</a></li>
<li><a href="/api/stats">Statistics API</a></li>
</ul>
</body>
</html>
`)
})
}
// Wrap with logging middleware
handler := logger.Middleware(mux)
// Create HTTP server
a.server = &http.Server{
Addr: fmt.Sprintf("%s:%d", a.config.Server.Host, a.config.Server.Port),
Handler: handler,
ReadTimeout: a.config.Server.ReadTimeout,
WriteTimeout: a.config.Server.WriteTimeout,
}
log.Info().
Str("addr", a.server.Addr).
Msg("HTTP server configured")
return nil
}
// Run starts the application
func (a *App) Run() error {
ctx := context.Background()
// Start WebSocket server
a.wsServer.Start(ctx)
// Start pre-warming worker
a.prewarmWorker.Start(ctx)
// Start rescan worker if enabled
if a.rescanWorker != nil {
go a.rescanWorker.Start(ctx)
}
// Start HTTP server in goroutine
errChan := make(chan error, 1)
go func() {
log.Info().
Str("addr", a.server.Addr).
Msg("Starting HTTP server")
if err := a.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
errChan <- err
}
}()
// Wait for interrupt signal
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
select {
case err := <-errChan:
return fmt.Errorf("server error: %w", err)
case sig := <-sigChan:
log.Info().
Str("signal", sig.String()).
Msg("Shutdown signal received")
}
// Graceful shutdown
return a.Shutdown()
}
// Shutdown gracefully shuts down the application
func (a *App) Shutdown() error {
log.Info().Msg("Starting graceful shutdown")
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// Stop HTTP server
if err := a.server.Shutdown(ctx); err != nil {
log.Error().Err(err).Msg("Error shutting down HTTP server")
}
// Stop pre-warming worker
a.prewarmWorker.Stop()
// Stop rescan worker if running
if a.rescanWorker != nil {
a.rescanWorker.Stop()
}
// Close analytics engine
a.analyticsEngine.Close()
// Close storage
if err := a.storage.Close(); err != nil {
log.Error().Err(err).Msg("Error closing storage")
}
// Close metadata store
if err := a.metadata.Close(); err != nil {
log.Error().Err(err).Msg("Error closing metadata store")
}
// Close lock manager if initialized
if a.lockManager != nil {
if err := a.lockManager.Close(); err != nil {
log.Error().Err(err).Msg("Error closing lock manager")
}
}
log.Info().Msg("Shutdown complete")
return nil
}
+415
View File
@@ -0,0 +1,415 @@
package app
import (
"net/http"
"strings"
"time"
"github.com/lukaszraczylo/gohoarder/internal/version"
"github.com/lukaszraczylo/gohoarder/pkg/errors"
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
"github.com/lukaszraczylo/gohoarder/pkg/websocket"
"github.com/rs/zerolog/log"
)
// handlePackages handles /api/packages endpoint
func (a *App) handlePackages(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
// Check if this is a vulnerability endpoint request
if strings.HasSuffix(r.URL.Path, "/vulnerabilities") {
a.handleVulnerabilities(w, r)
return
}
switch r.Method {
case "GET":
a.handleListPackages(w, r)
case "DELETE":
a.handleDeletePackage(w, r)
default:
errors.WriteErrorSimple(w, errors.BadRequest("method not allowed"))
}
}
// handleListPackages returns list of cached packages
func (a *App) handleListPackages(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Get packages from metadata store
allPackages, err := a.metadata.ListPackages(ctx, &metadata.ListOptions{
Limit: 1000, // Get more to account for duplicates
Offset: 0,
})
if err != nil {
log.Error().Err(err).Msg("Failed to list packages")
errors.WriteErrorSimple(w, errors.InternalServer("failed to list packages"))
return
}
// Filter, clean, and deduplicate packages
seen := make(map[string]*metadata.Package)
for _, pkg := range allPackages {
// Skip metadata entries
if pkg.Version == "list" || pkg.Version == "latest" {
continue
}
// Clean the package name (remove /@v/version.ext suffix)
cleanName := pkg.Name
if idx := strings.Index(cleanName, "/@v/"); idx != -1 {
cleanName = cleanName[:idx]
}
// Create deduplication key
key := cleanName + "@" + pkg.Version
// Keep the entry with the largest size (typically .zip files)
if existing, ok := seen[key]; !ok || pkg.Size > existing.Size {
// Create a copy with cleaned name
cleanPkg := *pkg
cleanPkg.Name = cleanName
seen[key] = &cleanPkg
}
}
// Convert map to slice
packages := make([]*metadata.Package, 0, len(seen))
for _, pkg := range seen {
packages = append(packages, pkg)
}
// Enhance packages with vulnerability information if security scanning is enabled
var response map[string]interface{}
if a.config.Security.Enabled {
enhancedPackages := make([]map[string]interface{}, 0, len(packages))
for _, pkg := range packages {
pkgMap := map[string]interface{}{
"id": pkg.ID,
"registry": pkg.Registry,
"name": pkg.Name,
"version": pkg.Version,
"size": pkg.Size,
"checksum_sha256": pkg.ChecksumSHA256,
"cached_at": pkg.CachedAt,
"last_accessed": pkg.LastAccessed,
"download_count": pkg.DownloadCount,
}
// Add vulnerability info if scanned
if pkg.SecurityScanned {
scanResult, err := a.metadata.GetScanResult(ctx, pkg.Registry, pkg.Name, pkg.Version)
if err == nil && scanResult != nil {
// Count vulnerabilities by severity
severityCounts := make(map[string]int)
for _, vuln := range scanResult.Vulnerabilities {
severityCounts[strings.ToUpper(vuln.Severity)]++
}
pkgMap["vulnerabilities"] = map[string]interface{}{
"scanned": true,
"status": scanResult.Status,
"scannedAt": scanResult.ScannedAt.Format(time.RFC3339),
"counts": map[string]int{
"critical": severityCounts["CRITICAL"],
"high": severityCounts["HIGH"],
"medium": severityCounts["MEDIUM"],
"low": severityCounts["LOW"],
},
"total": scanResult.VulnerabilityCount,
}
} else {
pkgMap["vulnerabilities"] = map[string]interface{}{
"scanned": false,
"status": "pending",
}
}
} else {
pkgMap["vulnerabilities"] = map[string]interface{}{
"scanned": false,
"status": "not_scanned",
}
}
enhancedPackages = append(enhancedPackages, pkgMap)
}
response = map[string]interface{}{
"packages": enhancedPackages,
"total": len(enhancedPackages),
}
} else {
response = map[string]interface{}{
"packages": packages,
"total": len(packages),
}
}
// Success response
errors.WriteJSONSimple(w, http.StatusOK, response)
}
// handleDeletePackage deletes a cached package
func (a *App) handleDeletePackage(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Parse path: /api/packages/{registry}/{name}/{version}
// For Go packages, name can contain slashes (e.g., github.com/user/repo)
// Version is always the last segment
path := strings.TrimPrefix(r.URL.Path, "/api/packages/")
parts := strings.Split(path, "/")
if len(parts) < 3 {
errors.WriteErrorSimple(w, errors.BadRequest("invalid path format, expected /api/packages/{registry}/{name}/{version}"))
return
}
registry := parts[0]
version := parts[len(parts)-1]
name := strings.Join(parts[1:len(parts)-1], "/")
// For Go packages, we need to find and delete all cache entries (.info, .mod, .zip)
// For other registries, we can delete directly
var deletedCount int
var lastErr error
if registry == "go" {
// List all packages matching the base name and version
allPackages, err := a.metadata.ListPackages(ctx, &metadata.ListOptions{
Limit: 1000,
})
if err != nil {
log.Error().Err(err).Msg("Failed to list packages for deletion")
errors.WriteErrorSimple(w, errors.InternalServer("failed to list packages"))
return
}
log.Debug().
Str("registry", registry).
Str("name", name).
Str("version", version).
Int("total_packages", len(allPackages)).
Msg("Searching for packages to delete")
// Find and delete all entries for this package
for _, pkg := range allPackages {
if pkg.Registry != registry || pkg.Version != version {
continue
}
// Check if this package name matches (either exact or with /@v/ suffix)
cleanName := pkg.Name
if idx := strings.Index(cleanName, "/@v/"); idx != -1 {
cleanName = cleanName[:idx]
}
log.Debug().
Str("db_name", pkg.Name).
Str("clean_name", cleanName).
Str("search_name", name).
Bool("matches", cleanName == name).
Msg("Checking package")
if cleanName == name {
if err := a.cache.Delete(ctx, pkg.Registry, pkg.Name, pkg.Version); err != nil {
log.Warn().
Err(err).
Str("registry", pkg.Registry).
Str("name", pkg.Name).
Str("version", pkg.Version).
Msg("Failed to delete package variant")
lastErr = err
} else {
deletedCount++
log.Info().
Str("registry", pkg.Registry).
Str("name", pkg.Name).
Str("version", pkg.Version).
Msg("Deleted package variant")
}
}
}
log.Debug().
Int("deleted_count", deletedCount).
Msg("Delete operation completed")
if deletedCount == 0 {
errors.WriteErrorSimple(w, errors.NotFound("package not found"))
return
}
if lastErr != nil && deletedCount == 0 {
errors.WriteErrorSimple(w, errors.InternalServer("failed to delete package"))
return
}
} else {
// For NPM and PyPI, delete directly
if err := a.cache.Delete(ctx, registry, name, version); err != nil {
log.Error().
Err(err).
Str("registry", registry).
Str("name", name).
Str("version", version).
Msg("Failed to delete package")
errors.WriteErrorSimple(w, errors.InternalServer("failed to delete package"))
return
}
deletedCount = 1
}
// Broadcast event via WebSocket
a.wsServer.Broadcast(websocket.EventPackageDeleted, map[string]interface{}{
"registry": registry,
"name": name,
"version": version,
})
// Success response
response := map[string]interface{}{
"deleted": true,
"package": map[string]string{
"registry": registry,
"name": name,
"version": version,
},
}
// For Go packages, include count of deleted variants
if registry == "go" {
response["deleted_count"] = deletedCount
}
errors.WriteJSONSimple(w, http.StatusOK, response)
}
// handleStats handles /api/stats endpoint
func (a *App) handleStats(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
if r.Method != "GET" {
errors.WriteErrorSimple(w, errors.BadRequest("method not allowed"))
return
}
ctx := r.Context()
// Get cache statistics for all registries
cacheStats, err := a.cache.GetStats(ctx, "")
if err != nil {
log.Error().Err(err).Msg("Failed to get cache stats")
cacheStats = &metadata.Stats{}
}
// Get all packages to calculate total size and downloads
packages, err := a.metadata.ListPackages(ctx, nil)
if err != nil {
log.Error().Err(err).Msg("Failed to list packages")
packages = []*metadata.Package{}
}
// Calculate totals and registry breakdown from actual packages (exclude metadata entries like "list", "latest")
var totalSize int64
var totalDownloads int64
var actualPackageCount int
registryStats := make(map[string]map[string]interface{})
for _, pkg := range packages {
// Skip metadata entries
if pkg.Version == "list" || pkg.Version == "latest" {
continue
}
totalSize += pkg.Size
totalDownloads += int64(pkg.DownloadCount)
actualPackageCount++
// Track per-registry stats
if _, ok := registryStats[pkg.Registry]; !ok {
registryStats[pkg.Registry] = map[string]interface{}{
"count": 0,
"size": int64(0),
"downloads": int64(0),
}
}
registryStats[pkg.Registry]["count"] = registryStats[pkg.Registry]["count"].(int) + 1
registryStats[pkg.Registry]["size"] = registryStats[pkg.Registry]["size"].(int64) + pkg.Size
registryStats[pkg.Registry]["downloads"] = registryStats[pkg.Registry]["downloads"].(int64) + int64(pkg.DownloadCount)
}
// Combine statistics
stats := map[string]interface{}{
"total_packages": actualPackageCount,
"total_downloads": totalDownloads,
"total_size": totalSize,
"cache_hits": cacheStats.TotalDownloads,
"cache_misses": 0, // TODO: Track cache misses
"cache_evictions": 0, // TODO: Track evictions
"cache_size": cacheStats.TotalSize,
"scanned_packages": cacheStats.ScannedPackages,
"vulnerable_packages": cacheStats.VulnerablePackages,
}
// Convert registry stats to interface map
registries := make(map[string]interface{})
for registry, regStats := range registryStats {
registries[registry] = regStats
}
errors.WriteJSONSimple(w, http.StatusOK, map[string]interface{}{
"stats": stats,
"registries": registries,
})
}
// handleInfo handles /api/info endpoint
func (a *App) handleInfo(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
if r.Method != "GET" {
errors.WriteErrorSimple(w, errors.BadRequest("method not allowed"))
return
}
info := map[string]interface{}{
"name": "GoHoarder",
"version": version.Version,
"config": map[string]interface{}{
"storage_backend": a.config.Storage.Backend,
"metadata_backend": a.config.Metadata.Backend,
"cache_ttl": a.config.Cache.DefaultTTL.String(),
"max_cache_size": a.config.Cache.MaxSizeBytes,
},
"features": map[string]bool{
"distributed_locking": a.lockManager != nil,
"security_scanning": a.config.Security.Enabled,
"pre_warming": a.prewarmWorker != nil,
"websockets": true,
"analytics": true,
},
}
errors.WriteJSONSimple(w, http.StatusOK, info)
}
+413
View File
@@ -0,0 +1,413 @@
package app
import (
"net/http"
"strings"
"github.com/lukaszraczylo/gohoarder/internal/version"
"github.com/lukaszraczylo/gohoarder/pkg/errors"
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
"github.com/lukaszraczylo/gohoarder/pkg/websocket"
"github.com/rs/zerolog/log"
)
// handlePackages handles /api/packages endpoint
func (a *App) handlePackages(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
// Check if this is a vulnerability endpoint request
if strings.HasSuffix(r.URL.Path, "/vulnerabilities") {
a.handleVulnerabilities(w, r)
return
}
switch r.Method {
case "GET":
a.handleListPackages(w, r)
case "DELETE":
a.handleDeletePackage(w, r)
default:
errors.WriteErrorSimple(w, errors.BadRequest("method not allowed"))
}
}
// handleListPackages returns list of cached packages
func (a *App) handleListPackages(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Get packages from metadata store
allPackages, err := a.metadata.ListPackages(ctx, &metadata.ListOptions{
Limit: 1000, // Get more to account for duplicates
Offset: 0,
})
if err != nil {
log.Error().Err(err).Msg("Failed to list packages")
errors.WriteErrorSimple(w, errors.InternalServer("failed to list packages"))
return
}
// Filter, clean, and deduplicate packages
seen := make(map[string]*metadata.Package)
for _, pkg := range allPackages {
// Skip metadata entries
if pkg.Version == "list" || pkg.Version == "latest" {
continue
}
// Clean the package name (remove /@v/version.ext suffix)
cleanName := pkg.Name
if idx := strings.Index(cleanName, "/@v/"); idx != -1 {
cleanName = cleanName[:idx]
}
// Create deduplication key
key := cleanName + "@" + pkg.Version
// Keep the entry with the largest size (typically .zip files)
if existing, ok := seen[key]; !ok || pkg.Size > existing.Size {
// Create a copy with cleaned name
cleanPkg := *pkg
cleanPkg.Name = cleanName
seen[key] = &cleanPkg
}
}
// Convert map to slice
packages := make([]*metadata.Package, 0, len(seen))
for _, pkg := range seen {
packages = append(packages, pkg)
}
// Enhance packages with vulnerability information if security scanning is enabled
var response map[string]interface{}
if a.config.Security.Enabled {
enhancedPackages := make([]map[string]interface{}, 0, len(packages))
for _, pkg := range packages {
pkgMap := map[string]interface{}{
"id": pkg.ID,
"registry": pkg.Registry,
"name": pkg.Name,
"version": pkg.Version,
"size": pkg.Size,
"checksum_sha256": pkg.ChecksumSHA256,
"cached_at": pkg.CachedAt,
"last_accessed": pkg.LastAccessed,
"download_count": pkg.DownloadCount,
}
// Add vulnerability info if scanned
if pkg.SecurityScanned {
scanResult, err := a.metadata.GetScanResult(ctx, pkg.Registry, pkg.Name, pkg.Version)
if err == nil && scanResult != nil {
// Count vulnerabilities by severity
severityCounts := make(map[string]int)
for _, vuln := range scanResult.Vulnerabilities {
severityCounts[strings.ToUpper(vuln.Severity)]++
}
pkgMap["vulnerabilities"] = map[string]interface{}{
"scanned": true,
"status": scanResult.Status,
"counts": map[string]int{
"critical": severityCounts["CRITICAL"],
"high": severityCounts["HIGH"],
"medium": severityCounts["MEDIUM"],
"low": severityCounts["LOW"],
},
"total": scanResult.VulnerabilityCount,
}
} else {
pkgMap["vulnerabilities"] = map[string]interface{}{
"scanned": false,
"status": "pending",
}
}
} else {
pkgMap["vulnerabilities"] = map[string]interface{}{
"scanned": false,
"status": "not_scanned",
}
}
enhancedPackages = append(enhancedPackages, pkgMap)
}
response = map[string]interface{}{
"packages": enhancedPackages,
"total": len(enhancedPackages),
}
} else {
response = map[string]interface{}{
"packages": packages,
"total": len(packages),
}
}
// Success response
errors.WriteJSONSimple(w, http.StatusOK, response)
}
// handleDeletePackage deletes a cached package
func (a *App) handleDeletePackage(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Parse path: /api/packages/{registry}/{name}/{version}
// For Go packages, name can contain slashes (e.g., github.com/user/repo)
// Version is always the last segment
path := strings.TrimPrefix(r.URL.Path, "/api/packages/")
parts := strings.Split(path, "/")
if len(parts) < 3 {
errors.WriteErrorSimple(w, errors.BadRequest("invalid path format, expected /api/packages/{registry}/{name}/{version}"))
return
}
registry := parts[0]
version := parts[len(parts)-1]
name := strings.Join(parts[1:len(parts)-1], "/")
// For Go packages, we need to find and delete all cache entries (.info, .mod, .zip)
// For other registries, we can delete directly
var deletedCount int
var lastErr error
if registry == "go" {
// List all packages matching the base name and version
allPackages, err := a.metadata.ListPackages(ctx, &metadata.ListOptions{
Limit: 1000,
})
if err != nil {
log.Error().Err(err).Msg("Failed to list packages for deletion")
errors.WriteErrorSimple(w, errors.InternalServer("failed to list packages"))
return
}
log.Debug().
Str("registry", registry).
Str("name", name).
Str("version", version).
Int("total_packages", len(allPackages)).
Msg("Searching for packages to delete")
// Find and delete all entries for this package
for _, pkg := range allPackages {
if pkg.Registry != registry || pkg.Version != version {
continue
}
// Check if this package name matches (either exact or with /@v/ suffix)
cleanName := pkg.Name
if idx := strings.Index(cleanName, "/@v/"); idx != -1 {
cleanName = cleanName[:idx]
}
log.Debug().
Str("db_name", pkg.Name).
Str("clean_name", cleanName).
Str("search_name", name).
Bool("matches", cleanName == name).
Msg("Checking package")
if cleanName == name {
if err := a.cache.Delete(ctx, pkg.Registry, pkg.Name, pkg.Version); err != nil {
log.Warn().
Err(err).
Str("registry", pkg.Registry).
Str("name", pkg.Name).
Str("version", pkg.Version).
Msg("Failed to delete package variant")
lastErr = err
} else {
deletedCount++
log.Info().
Str("registry", pkg.Registry).
Str("name", pkg.Name).
Str("version", pkg.Version).
Msg("Deleted package variant")
}
}
}
log.Debug().
Int("deleted_count", deletedCount).
Msg("Delete operation completed")
if deletedCount == 0 {
errors.WriteErrorSimple(w, errors.NotFound("package not found"))
return
}
if lastErr != nil && deletedCount == 0 {
errors.WriteErrorSimple(w, errors.InternalServer("failed to delete package"))
return
}
} else {
// For NPM and PyPI, delete directly
if err := a.cache.Delete(ctx, registry, name, version); err != nil {
log.Error().
Err(err).
Str("registry", registry).
Str("name", name).
Str("version", version).
Msg("Failed to delete package")
errors.WriteErrorSimple(w, errors.InternalServer("failed to delete package"))
return
}
deletedCount = 1
}
// Broadcast event via WebSocket
a.wsServer.Broadcast(websocket.EventPackageDeleted, map[string]interface{}{
"registry": registry,
"name": name,
"version": version,
})
// Success response
response := map[string]interface{}{
"deleted": true,
"package": map[string]string{
"registry": registry,
"name": name,
"version": version,
},
}
// For Go packages, include count of deleted variants
if registry == "go" {
response["deleted_count"] = deletedCount
}
errors.WriteJSONSimple(w, http.StatusOK, response)
}
// handleStats handles /api/stats endpoint
func (a *App) handleStats(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
if r.Method != "GET" {
errors.WriteErrorSimple(w, errors.BadRequest("method not allowed"))
return
}
ctx := r.Context()
// Get cache statistics for all registries
cacheStats, err := a.cache.GetStats(ctx, "")
if err != nil {
log.Error().Err(err).Msg("Failed to get cache stats")
cacheStats = &metadata.Stats{}
}
// Get all packages to calculate total size and downloads
packages, err := a.metadata.ListPackages(ctx, nil)
if err != nil {
log.Error().Err(err).Msg("Failed to list packages")
packages = []*metadata.Package{}
}
// Calculate totals and registry breakdown from actual packages (exclude metadata entries like "list", "latest")
var totalSize int64
var totalDownloads int64
var actualPackageCount int
registryStats := make(map[string]map[string]interface{})
for _, pkg := range packages {
// Skip metadata entries
if pkg.Version == "list" || pkg.Version == "latest" {
continue
}
totalSize += pkg.Size
totalDownloads += int64(pkg.DownloadCount)
actualPackageCount++
// Track per-registry stats
if _, ok := registryStats[pkg.Registry]; !ok {
registryStats[pkg.Registry] = map[string]interface{}{
"count": 0,
"size": int64(0),
"downloads": int64(0),
}
}
registryStats[pkg.Registry]["count"] = registryStats[pkg.Registry]["count"].(int) + 1
registryStats[pkg.Registry]["size"] = registryStats[pkg.Registry]["size"].(int64) + pkg.Size
registryStats[pkg.Registry]["downloads"] = registryStats[pkg.Registry]["downloads"].(int64) + int64(pkg.DownloadCount)
}
// Combine statistics
stats := map[string]interface{}{
"total_packages": actualPackageCount,
"total_downloads": totalDownloads,
"total_size": totalSize,
"cache_hits": cacheStats.TotalDownloads,
"cache_misses": 0, // TODO: Track cache misses
"cache_evictions": 0, // TODO: Track evictions
"cache_size": cacheStats.TotalSize,
"scanned_packages": cacheStats.ScannedPackages,
"vulnerable_packages": cacheStats.VulnerablePackages,
}
// Convert registry stats to interface map
registries := make(map[string]interface{})
for registry, regStats := range registryStats {
registries[registry] = regStats
}
errors.WriteJSONSimple(w, http.StatusOK, map[string]interface{}{
"stats": stats,
"registries": registries,
})
}
// handleInfo handles /api/info endpoint
func (a *App) handleInfo(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
if r.Method != "GET" {
errors.WriteErrorSimple(w, errors.BadRequest("method not allowed"))
return
}
info := map[string]interface{}{
"name": "GoHoarder",
"version": version.Version,
"config": map[string]interface{}{
"storage_backend": a.config.Storage.Backend,
"metadata_backend": a.config.Metadata.Backend,
"cache_ttl": a.config.Cache.DefaultTTL.String(),
"max_cache_size": a.config.Cache.MaxSizeBytes,
},
"features": map[string]bool{
"distributed_locking": a.lockManager != nil,
"security_scanning": a.config.Security.Enabled,
"pre_warming": a.prewarmWorker != nil,
"websockets": true,
"analytics": true,
},
}
errors.WriteJSONSimple(w, http.StatusOK, info)
}
+415
View File
@@ -0,0 +1,415 @@
package app
import (
"net/http"
"strings"
"time"
"github.com/lukaszraczylo/gohoarder/internal/version"
"github.com/lukaszraczylo/gohoarder/pkg/errors"
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
"github.com/lukaszraczylo/gohoarder/pkg/websocket"
"github.com/rs/zerolog/log"
)
// handlePackages handles /api/packages endpoint
func (a *App) handlePackages(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
// Check if this is a vulnerability endpoint request
if strings.HasSuffix(r.URL.Path, "/vulnerabilities") {
a.handleVulnerabilities(w, r)
return
}
switch r.Method {
case "GET":
a.handleListPackages(w, r)
case "DELETE":
a.handleDeletePackage(w, r)
default:
errors.WriteErrorSimple(w, errors.BadRequest("method not allowed"))
}
}
// handleListPackages returns list of cached packages
func (a *App) handleListPackages(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Get packages from metadata store
allPackages, err := a.metadata.ListPackages(ctx, &metadata.ListOptions{
Limit: 1000, // Get more to account for duplicates
Offset: 0,
})
if err != nil {
log.Error().Err(err).Msg("Failed to list packages")
errors.WriteErrorSimple(w, errors.InternalServer("failed to list packages"))
return
}
// Filter, clean, and deduplicate packages
seen := make(map[string]*metadata.Package)
for _, pkg := range allPackages {
// Skip metadata entries
if pkg.Version == "list" || pkg.Version == "latest" {
continue
}
// Clean the package name (remove /@v/version.ext suffix)
cleanName := pkg.Name
if idx := strings.Index(cleanName, "/@v/"); idx != -1 {
cleanName = cleanName[:idx]
}
// Create deduplication key
key := cleanName + "@" + pkg.Version
// Keep the entry with the largest size (typically .zip files)
if existing, ok := seen[key]; !ok || pkg.Size > existing.Size {
// Create a copy with cleaned name
cleanPkg := *pkg
cleanPkg.Name = cleanName
seen[key] = &cleanPkg
}
}
// Convert map to slice
packages := make([]*metadata.Package, 0, len(seen))
for _, pkg := range seen {
packages = append(packages, pkg)
}
// Enhance packages with vulnerability information if security scanning is enabled
var response map[string]interface{}
if a.config.Security.Enabled {
enhancedPackages := make([]map[string]interface{}, 0, len(packages))
for _, pkg := range packages {
pkgMap := map[string]interface{}{
"id": pkg.ID,
"registry": pkg.Registry,
"name": pkg.Name,
"version": pkg.Version,
"size": pkg.Size,
"checksum_sha256": pkg.ChecksumSHA256,
"cached_at": pkg.CachedAt,
"last_accessed": pkg.LastAccessed,
"download_count": pkg.DownloadCount,
}
// Add vulnerability info if scanned
if pkg.SecurityScanned {
scanResult, err := a.metadata.GetScanResult(ctx, pkg.Registry, pkg.Name, pkg.Version)
if err == nil && scanResult != nil {
// Count vulnerabilities by severity
severityCounts := make(map[string]int)
for _, vuln := range scanResult.Vulnerabilities {
severityCounts[strings.ToUpper(vuln.Severity)]++
}
pkgMap["vulnerabilities"] = map[string]interface{}{
"scanned": true,
"status": scanResult.Status,
"scannedAt": scanResult.ScannedAt.Format(time.RFC3339),
"counts": map[string]int{
"critical": severityCounts["CRITICAL"],
"high": severityCounts["HIGH"],
"medium": severityCounts["MEDIUM"],
"low": severityCounts["LOW"],
},
"total": scanResult.VulnerabilityCount,
}
} else {
pkgMap["vulnerabilities"] = map[string]interface{}{
"scanned": false,
"status": "pending",
}
}
} else {
pkgMap["vulnerabilities"] = map[string]interface{}{
"scanned": false,
"status": "not_scanned",
}
}
enhancedPackages = append(enhancedPackages, pkgMap)
}
response = map[string]interface{}{
"packages": enhancedPackages,
"total": len(enhancedPackages),
}
} else {
response = map[string]interface{}{
"packages": packages,
"total": len(packages),
}
}
// Success response
errors.WriteJSONSimple(w, http.StatusOK, response)
}
// handleDeletePackage deletes a cached package
func (a *App) handleDeletePackage(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Parse path: /api/packages/{registry}/{name}/{version}
// For Go packages, name can contain slashes (e.g., github.com/user/repo)
// Version is always the last segment
path := strings.TrimPrefix(r.URL.Path, "/api/packages/")
parts := strings.Split(path, "/")
if len(parts) < 3 {
errors.WriteErrorSimple(w, errors.BadRequest("invalid path format, expected /api/packages/{registry}/{name}/{version}"))
return
}
registry := parts[0]
version := parts[len(parts)-1]
name := strings.Join(parts[1:len(parts)-1], "/")
// For Go packages, we need to find and delete all cache entries (.info, .mod, .zip)
// For other registries, we can delete directly
var deletedCount int
var lastErr error
if registry == "go" {
// List all packages matching the base name and version
allPackages, err := a.metadata.ListPackages(ctx, &metadata.ListOptions{
Limit: 1000,
})
if err != nil {
log.Error().Err(err).Msg("Failed to list packages for deletion")
errors.WriteErrorSimple(w, errors.InternalServer("failed to list packages"))
return
}
log.Debug().
Str("registry", registry).
Str("name", name).
Str("version", version).
Int("total_packages", len(allPackages)).
Msg("Searching for packages to delete")
// Find and delete all entries for this package
for _, pkg := range allPackages {
if pkg.Registry != registry || pkg.Version != version {
continue
}
// Check if this package name matches (either exact or with /@v/ suffix)
cleanName := pkg.Name
if idx := strings.Index(cleanName, "/@v/"); idx != -1 {
cleanName = cleanName[:idx]
}
log.Debug().
Str("db_name", pkg.Name).
Str("clean_name", cleanName).
Str("search_name", name).
Bool("matches", cleanName == name).
Msg("Checking package")
if cleanName == name {
if err := a.cache.Delete(ctx, pkg.Registry, pkg.Name, pkg.Version); err != nil {
log.Warn().
Err(err).
Str("registry", pkg.Registry).
Str("name", pkg.Name).
Str("version", pkg.Version).
Msg("Failed to delete package variant")
lastErr = err
} else {
deletedCount++
log.Info().
Str("registry", pkg.Registry).
Str("name", pkg.Name).
Str("version", pkg.Version).
Msg("Deleted package variant")
}
}
}
log.Debug().
Int("deleted_count", deletedCount).
Msg("Delete operation completed")
if deletedCount == 0 {
errors.WriteErrorSimple(w, errors.NotFound("package not found"))
return
}
if lastErr != nil && deletedCount == 0 {
errors.WriteErrorSimple(w, errors.InternalServer("failed to delete package"))
return
}
} else {
// For NPM and PyPI, delete directly
if err := a.cache.Delete(ctx, registry, name, version); err != nil {
log.Error().
Err(err).
Str("registry", registry).
Str("name", name).
Str("version", version).
Msg("Failed to delete package")
errors.WriteErrorSimple(w, errors.InternalServer("failed to delete package"))
return
}
deletedCount = 1
}
// Broadcast event via WebSocket
a.wsServer.Broadcast(websocket.EventPackageDeleted, map[string]interface{}{
"registry": registry,
"name": name,
"version": version,
})
// Success response
response := map[string]interface{}{
"deleted": true,
"package": map[string]string{
"registry": registry,
"name": name,
"version": version,
},
}
// For Go packages, include count of deleted variants
if registry == "go" {
response["deleted_count"] = deletedCount
}
errors.WriteJSONSimple(w, http.StatusOK, response)
}
// handleStats handles /api/stats endpoint
func (a *App) handleStats(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
if r.Method != "GET" {
errors.WriteErrorSimple(w, errors.BadRequest("method not allowed"))
return
}
ctx := r.Context()
// Get cache statistics for all registries
cacheStats, err := a.cache.GetStats(ctx, "")
if err != nil {
log.Error().Err(err).Msg("Failed to get cache stats")
cacheStats = &metadata.Stats{}
}
// Get all packages to calculate total size and downloads
packages, err := a.metadata.ListPackages(ctx, nil)
if err != nil {
log.Error().Err(err).Msg("Failed to list packages")
packages = []*metadata.Package{}
}
// Calculate totals and registry breakdown from actual packages (exclude metadata entries like "list", "latest")
var totalSize int64
var totalDownloads int64
var actualPackageCount int
registryStats := make(map[string]map[string]interface{})
for _, pkg := range packages {
// Skip metadata entries
if pkg.Version == "list" || pkg.Version == "latest" {
continue
}
totalSize += pkg.Size
totalDownloads += int64(pkg.DownloadCount)
actualPackageCount++
// Track per-registry stats
if _, ok := registryStats[pkg.Registry]; !ok {
registryStats[pkg.Registry] = map[string]interface{}{
"count": 0,
"size": int64(0),
"downloads": int64(0),
}
}
registryStats[pkg.Registry]["count"] = registryStats[pkg.Registry]["count"].(int) + 1
registryStats[pkg.Registry]["size"] = registryStats[pkg.Registry]["size"].(int64) + pkg.Size
registryStats[pkg.Registry]["downloads"] = registryStats[pkg.Registry]["downloads"].(int64) + int64(pkg.DownloadCount)
}
// Combine statistics
stats := map[string]interface{}{
"total_packages": actualPackageCount,
"total_downloads": totalDownloads,
"total_size": totalSize,
"cache_hits": cacheStats.TotalDownloads,
"cache_misses": 0, // TODO: Track cache misses
"cache_evictions": 0, // TODO: Track evictions
"cache_size": cacheStats.TotalSize,
"scanned_packages": cacheStats.ScannedPackages,
"vulnerable_packages": cacheStats.VulnerablePackages,
}
// Convert registry stats to interface map
registries := make(map[string]interface{})
for registry, regStats := range registryStats {
registries[registry] = regStats
}
errors.WriteJSONSimple(w, http.StatusOK, map[string]interface{}{
"stats": stats,
"registries": registries,
})
}
// handleInfo handles /api/info endpoint
func (a *App) handleInfo(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
if r.Method != "GET" {
errors.WriteErrorSimple(w, errors.BadRequest("method not allowed"))
return
}
info := map[string]interface{}{
"name": "GoHoarder",
"version": version.Version,
"config": map[string]interface{}{
"storage_backend": a.config.Storage.Backend,
"metadata_backend": a.config.Metadata.Backend,
"cache_ttl": a.config.Cache.DefaultTTL.String(),
"max_cache_size": a.config.Cache.MaxSizeBytes,
},
"features": map[string]bool{
"distributed_locking": a.lockManager != nil,
"security_scanning": a.config.Security.Enabled,
"pre_warming": a.prewarmWorker != nil,
"websockets": true,
"analytics": true,
},
}
errors.WriteJSONSimple(w, http.StatusOK, info)
}
+378
View File
@@ -0,0 +1,378 @@
package app
import (
"encoding/json"
"io"
"net/http"
"strings"
"time"
"github.com/lukaszraczylo/gohoarder/pkg/auth"
"github.com/lukaszraczylo/gohoarder/pkg/errors"
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
"github.com/lukaszraczylo/gohoarder/pkg/uuid"
"github.com/rs/zerolog/log"
)
// requireAdmin middleware checks for admin authentication
func (a *App) requireAdmin(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Get API key from Authorization header
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
errors.WriteErrorSimple(w, errors.New(errors.ErrCodeUnauthorized, "missing authorization header"))
return
}
// Extract bearer token
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
errors.WriteErrorSimple(w, errors.New(errors.ErrCodeUnauthorized, "invalid authorization header format, expected: Bearer <token>"))
return
}
apiKey := parts[1]
// Validate API key
key, err := a.authManager.ValidateAPIKey(r.Context(), apiKey)
if err != nil {
errors.WriteErrorSimple(w, errors.New(errors.ErrCodeUnauthorized, "invalid or expired API key"))
return
}
// Check if user has admin role or bypass management permission
if key.Role != auth.RoleAdmin && !key.HasPermission(auth.PermissionManageBypasses) {
errors.WriteErrorSimple(w, errors.New(errors.ErrCodeForbidden, "insufficient permissions, admin role required"))
return
}
// Store user info in request context for handlers to use
// For now, we'll just proceed - could enhance with context.WithValue
next(w, r)
}
}
// handleAdminBypasses handles /api/admin/bypasses endpoint
func (a *App) handleAdminBypasses(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
switch r.Method {
case "GET":
a.requireAdmin(a.handleListBypasses)(w, r)
case "POST":
a.requireAdmin(a.handleCreateBypass)(w, r)
default:
errors.WriteErrorSimple(w, errors.BadRequest("method not allowed"))
}
}
// handleBypassByID handles /api/admin/bypasses/{id} endpoint
func (a *App) handleBypassByID(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, DELETE, PATCH, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
switch r.Method {
case "GET":
a.requireAdmin(a.handleGetBypass)(w, r)
case "DELETE":
a.requireAdmin(a.handleDeleteBypass)(w, r)
case "PATCH":
a.requireAdmin(a.handleUpdateBypass)(w, r)
default:
errors.WriteErrorSimple(w, errors.BadRequest("method not allowed"))
}
}
// handleListBypasses lists all CVE bypasses
func (a *App) handleListBypasses(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Parse query parameters
includeExpired := r.URL.Query().Get("include_expired") == "true"
activeOnly := r.URL.Query().Get("active_only") == "true"
bypassType := metadata.BypassType(r.URL.Query().Get("type"))
opts := &metadata.BypassListOptions{
IncludeExpired: includeExpired,
ActiveOnly: activeOnly,
Type: bypassType,
}
bypasses, err := a.metadata.ListCVEBypasses(ctx, opts)
if err != nil {
log.Error().Err(err).Msg("Failed to list CVE bypasses")
errors.WriteErrorSimple(w, errors.InternalServer("failed to list bypasses"))
return
}
errors.WriteJSONSimple(w, http.StatusOK, map[string]interface{}{
"bypasses": bypasses,
"total": len(bypasses),
})
}
// CreateBypassRequest represents the request body for creating a bypass
type CreateBypassRequest struct {
Type metadata.BypassType `json:"type"` // "cve" or "package"
Target string `json:"target"` // CVE ID or package name
Reason string `json:"reason"` // Why this bypass is needed
CreatedBy string `json:"created_by"` // Admin username
ExpiresInHours int `json:"expires_in_hours"` // How many hours until expiration
AppliesTo string `json:"applies_to,omitempty"` // Optional: limit CVE bypass to specific package
NotifyOnExpiry bool `json:"notify_on_expiry"` // Send notification when expired
}
// handleCreateBypass creates a new CVE bypass
func (a *App) handleCreateBypass(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Parse request body
body, err := io.ReadAll(r.Body)
if err != nil {
errors.WriteErrorSimple(w, errors.BadRequest("failed to read request body"))
return
}
defer r.Body.Close()
var req CreateBypassRequest
if err := json.Unmarshal(body, &req); err != nil {
errors.WriteErrorSimple(w, errors.BadRequest("invalid JSON in request body"))
return
}
// Validate request
if req.Type != metadata.BypassTypeCVE && req.Type != metadata.BypassTypePackage {
errors.WriteErrorSimple(w, errors.BadRequest("type must be 'cve' or 'package'"))
return
}
if req.Target == "" {
errors.WriteErrorSimple(w, errors.BadRequest("target is required"))
return
}
if req.Reason == "" {
errors.WriteErrorSimple(w, errors.BadRequest("reason is required"))
return
}
if req.CreatedBy == "" {
errors.WriteErrorSimple(w, errors.BadRequest("created_by is required"))
return
}
if req.ExpiresInHours <= 0 {
errors.WriteErrorSimple(w, errors.BadRequest("expires_in_hours must be greater than 0"))
return
}
// Create bypass
now := time.Now()
expiresAt := now.Add(time.Duration(req.ExpiresInHours) * time.Hour)
bypass := &metadata.CVEBypass{
ID: uuid.New().String(),
Type: req.Type,
Target: req.Target,
Reason: req.Reason,
CreatedBy: req.CreatedBy,
CreatedAt: now,
ExpiresAt: expiresAt,
AppliesTo: req.AppliesTo,
NotifyOnExpiry: req.NotifyOnExpiry,
Active: true,
}
// Save to database
if err := a.metadata.SaveCVEBypass(ctx, bypass); err != nil {
log.Error().Err(err).Msg("Failed to save CVE bypass")
errors.WriteErrorSimple(w, errors.InternalServer("failed to create bypass"))
return
}
log.Info().
Str("bypass_id", bypass.ID).
Str("type", string(bypass.Type)).
Str("target", bypass.Target).
Str("created_by", bypass.CreatedBy).
Time("expires_at", bypass.ExpiresAt).
Msg("CVE bypass created")
errors.WriteJSONSimple(w, http.StatusCreated, map[string]interface{}{
"bypass": bypass,
"message": "Bypass created successfully",
})
}
// handleGetBypass gets a specific bypass by ID
func (a *App) handleGetBypass(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Extract ID from path
path := strings.TrimPrefix(r.URL.Path, "/api/admin/bypasses/")
bypassID := path
if bypassID == "" {
errors.WriteErrorSimple(w, errors.BadRequest("bypass ID is required"))
return
}
// Get all bypasses and find the one with matching ID
bypasses, err := a.metadata.ListCVEBypasses(ctx, &metadata.BypassListOptions{
IncludeExpired: true,
})
if err != nil {
log.Error().Err(err).Msg("Failed to list bypasses")
errors.WriteErrorSimple(w, errors.InternalServer("failed to get bypass"))
return
}
for _, bypass := range bypasses {
if bypass.ID == bypassID {
errors.WriteJSONSimple(w, http.StatusOK, map[string]interface{}{
"bypass": bypass,
})
return
}
}
errors.WriteErrorSimple(w, errors.NotFound("bypass not found"))
}
// UpdateBypassRequest represents the request body for updating a bypass
type UpdateBypassRequest struct {
Active *bool `json:"active,omitempty"`
Reason string `json:"reason,omitempty"`
ExpiresInHours int `json:"expires_in_hours,omitempty"`
}
// handleUpdateBypass updates a bypass (activate/deactivate or extend expiration)
func (a *App) handleUpdateBypass(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Extract ID from path
path := strings.TrimPrefix(r.URL.Path, "/api/admin/bypasses/")
bypassID := path
if bypassID == "" {
errors.WriteErrorSimple(w, errors.BadRequest("bypass ID is required"))
return
}
// Parse request body
body, err := io.ReadAll(r.Body)
if err != nil {
errors.WriteErrorSimple(w, errors.BadRequest("failed to read request body"))
return
}
defer r.Body.Close()
var req UpdateBypassRequest
if err := json.Unmarshal(body, &req); err != nil {
errors.WriteErrorSimple(w, errors.BadRequest("invalid JSON in request body"))
return
}
// Get current bypass
bypasses, err := a.metadata.ListCVEBypasses(ctx, &metadata.BypassListOptions{
IncludeExpired: true,
})
if err != nil {
log.Error().Err(err).Msg("Failed to list bypasses")
errors.WriteErrorSimple(w, errors.InternalServer("failed to get bypass"))
return
}
var currentBypass *metadata.CVEBypass
for _, bypass := range bypasses {
if bypass.ID == bypassID {
currentBypass = bypass
break
}
}
if currentBypass == nil {
errors.WriteErrorSimple(w, errors.NotFound("bypass not found"))
return
}
// Update fields
if req.Active != nil {
currentBypass.Active = *req.Active
}
if req.Reason != "" {
currentBypass.Reason = req.Reason
}
if req.ExpiresInHours > 0 {
currentBypass.ExpiresAt = time.Now().Add(time.Duration(req.ExpiresInHours) * time.Hour)
}
// Save updated bypass
if err := a.metadata.SaveCVEBypass(ctx, currentBypass); err != nil {
log.Error().Err(err).Msg("Failed to update bypass")
errors.WriteErrorSimple(w, errors.InternalServer("failed to update bypass"))
return
}
log.Info().
Str("bypass_id", currentBypass.ID).
Bool("active", currentBypass.Active).
Msg("CVE bypass updated")
errors.WriteJSONSimple(w, http.StatusOK, map[string]interface{}{
"bypass": currentBypass,
"message": "Bypass updated successfully",
})
}
// handleDeleteBypass deletes a bypass
func (a *App) handleDeleteBypass(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Extract ID from path
path := strings.TrimPrefix(r.URL.Path, "/api/admin/bypasses/")
bypassID := path
if bypassID == "" {
errors.WriteErrorSimple(w, errors.BadRequest("bypass ID is required"))
return
}
// Delete bypass
if err := a.metadata.DeleteCVEBypass(ctx, bypassID); err != nil {
if strings.Contains(err.Error(), "not found") {
errors.WriteErrorSimple(w, errors.NotFound("bypass not found"))
} else {
log.Error().Err(err).Msg("Failed to delete bypass")
errors.WriteErrorSimple(w, errors.InternalServer("failed to delete bypass"))
}
return
}
log.Info().
Str("bypass_id", bypassID).
Msg("CVE bypass deleted")
errors.WriteJSONSimple(w, http.StatusOK, map[string]interface{}{
"deleted": true,
"bypass_id": bypassID,
"message": "Bypass deleted successfully",
})
}
+160
View File
@@ -0,0 +1,160 @@
package app
import (
"net/http"
"strings"
"github.com/lukaszraczylo/gohoarder/pkg/errors"
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
"github.com/rs/zerolog/log"
)
// handleVulnerabilities handles /api/packages/{registry}/{name}/{version}/vulnerabilities endpoint
func (a *App) handleVulnerabilities(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
if r.Method != "GET" {
errors.WriteErrorSimple(w, errors.BadRequest("method not allowed"))
return
}
ctx := r.Context()
// Parse path: /api/packages/{registry}/{name}/{version}/vulnerabilities
path := strings.TrimPrefix(r.URL.Path, "/api/packages/")
path = strings.TrimSuffix(path, "/vulnerabilities")
parts := strings.Split(path, "/")
if len(parts) < 3 {
errors.WriteErrorSimple(w, errors.BadRequest("invalid path format, expected /api/packages/{registry}/{name}/{version}/vulnerabilities"))
return
}
registry := parts[0]
version := parts[len(parts)-1]
name := strings.Join(parts[1:len(parts)-1], "/")
log.Debug().
Str("registry", registry).
Str("name", name).
Str("version", version).
Msg("Getting vulnerabilities for package")
// Get scan result from metadata store
scanResult, err := a.metadata.GetScanResult(ctx, registry, name, version)
if err != nil {
// Check if package exists
pkg, pkgErr := a.metadata.GetPackage(ctx, registry, name, version)
if pkgErr != nil {
errors.WriteErrorSimple(w, errors.NotFound("package not found"))
return
}
// Package exists but not scanned yet
errors.WriteJSONSimple(w, http.StatusOK, map[string]interface{}{
"package": map[string]string{
"registry": registry,
"name": name,
"version": version,
},
"scanned": false,
"status": "pending",
"vulnerabilities": []interface{}{},
"vulnerability_count": 0,
"message": "Package not yet scanned for vulnerabilities",
"security_scanned": pkg.SecurityScanned,
})
return
}
// Get active bypasses to show which vulnerabilities are bypassed
bypasses, err := a.metadata.GetActiveCVEBypasses(ctx)
if err != nil {
log.Warn().Err(err).Msg("Failed to get CVE bypasses")
bypasses = []*metadata.CVEBypass{}
}
// Build bypass map for fast lookup
bypassedCVEs := make(map[string]*metadata.CVEBypass)
packageKey := registry + "/" + name + "@" + version
packageKeyNoVersion := registry + "/" + name
for _, bypass := range bypasses {
if bypass.Type == metadata.BypassTypeCVE && bypass.Active {
// Check if bypass applies to this package
if bypass.AppliesTo == "" || bypass.AppliesTo == packageKey || bypass.AppliesTo == packageKeyNoVersion {
bypassedCVEs[strings.ToUpper(bypass.Target)] = bypass
}
}
}
// Enrich vulnerabilities with bypass information
enrichedVulns := make([]map[string]interface{}, 0, len(scanResult.Vulnerabilities))
severityCounts := make(map[string]int)
for _, vuln := range scanResult.Vulnerabilities {
bypassed := false
var bypassInfo map[string]interface{}
// Check if this CVE is bypassed
if bypass, ok := bypassedCVEs[strings.ToUpper(vuln.ID)]; ok {
bypassed = true
bypassInfo = map[string]interface{}{
"id": bypass.ID,
"reason": bypass.Reason,
"created_by": bypass.CreatedBy,
"expires_at": bypass.ExpiresAt,
}
} else {
// Count non-bypassed vulnerabilities by severity
severityCounts[strings.ToUpper(vuln.Severity)]++
}
enrichedVuln := map[string]interface{}{
"id": vuln.ID,
"severity": vuln.Severity,
"title": vuln.Title,
"description": vuln.Description,
"references": vuln.References,
"fixed_in": vuln.FixedIn,
"bypassed": bypassed,
}
if bypassed {
enrichedVuln["bypass"] = bypassInfo
}
enrichedVulns = append(enrichedVulns, enrichedVuln)
}
// Build response
response := map[string]interface{}{
"package": map[string]string{
"registry": registry,
"name": name,
"version": version,
},
"scanned": true,
"scanner": scanResult.Scanner,
"scanned_at": scanResult.ScannedAt,
"status": scanResult.Status,
"vulnerabilities": enrichedVulns,
"vulnerability_count": scanResult.VulnerabilityCount,
"severity_counts": map[string]int{
"critical": severityCounts["CRITICAL"],
"high": severityCounts["HIGH"],
"medium": severityCounts["MEDIUM"],
"low": severityCounts["LOW"],
},
"bypassed_count": len(scanResult.Vulnerabilities) - (severityCounts["CRITICAL"] + severityCounts["HIGH"] + severityCounts["MEDIUM"] + severityCounts["LOW"]),
}
errors.WriteJSONSimple(w, http.StatusOK, response)
}
+193
View File
@@ -0,0 +1,193 @@
package auth
import (
"context"
"crypto/rand"
"encoding/base64"
"sync"
"time"
"github.com/lukaszraczylo/gohoarder/pkg/errors"
"golang.org/x/crypto/bcrypt"
)
// Manager handles authentication and authorization
type Manager struct {
keys map[string]*APIKey
mu sync.RWMutex
}
// APIKey represents an API key
type APIKey struct {
ID string
Name string
HashedKey string
Role Role
CreatedAt time.Time
ExpiresAt *time.Time
LastUsedAt time.Time
Permissions []Permission
}
// Role represents user role
type Role string
const (
RoleReadOnly Role = "readonly"
RoleReadWrite Role = "readwrite"
RoleAdmin Role = "admin"
)
// Permission represents a specific permission
type Permission string
const (
PermissionReadPackage Permission = "package:read"
PermissionWritePackage Permission = "package:write"
PermissionDeletePackage Permission = "package:delete"
PermissionViewStats Permission = "stats:view"
PermissionManageKeys Permission = "keys:manage"
PermissionManageSettings Permission = "settings:manage"
PermissionScanPackages Permission = "scan:execute"
PermissionManageBypasses Permission = "bypasses:manage"
)
// New creates a new authentication manager
func New() *Manager {
return &Manager{
keys: make(map[string]*APIKey),
}
}
// GenerateAPIKey generates a new API key
func (m *Manager) GenerateAPIKey(name string, role Role, expiresIn *time.Duration) (*APIKey, string, error) {
// Generate random key
keyBytes := make([]byte, 32)
if _, err := rand.Read(keyBytes); err != nil {
return nil, "", errors.Wrap(err, errors.ErrCodeInternalServer, "failed to generate random key")
}
rawKey := base64.URLEncoding.EncodeToString(keyBytes)
// Hash the key
hashedKey, err := bcrypt.GenerateFromPassword([]byte(rawKey), bcrypt.DefaultCost)
if err != nil {
return nil, "", errors.Wrap(err, errors.ErrCodeInternalServer, "failed to hash key")
}
var expiresAt *time.Time
if expiresIn != nil {
t := time.Now().Add(*expiresIn)
expiresAt = &t
}
apiKey := &APIKey{
ID: generateID(),
Name: name,
HashedKey: string(hashedKey),
Role: role,
CreatedAt: time.Now(),
ExpiresAt: expiresAt,
Permissions: getPermissionsForRole(role),
}
m.mu.Lock()
m.keys[apiKey.ID] = apiKey
m.mu.Unlock()
return apiKey, rawKey, nil
}
// ValidateAPIKey validates an API key and returns the associated key object
func (m *Manager) ValidateAPIKey(ctx context.Context, rawKey string) (*APIKey, error) {
m.mu.RLock()
defer m.mu.RUnlock()
for _, apiKey := range m.keys {
// Check if key is expired
if apiKey.ExpiresAt != nil && time.Now().After(*apiKey.ExpiresAt) {
continue
}
// Compare hashed key
if err := bcrypt.CompareHashAndPassword([]byte(apiKey.HashedKey), []byte(rawKey)); err == nil {
// Update last used
apiKey.LastUsedAt = time.Now()
return apiKey, nil
}
}
return nil, errors.New(errors.ErrCodeUnauthorized, "invalid API key")
}
// RevokeAPIKey revokes an API key
func (m *Manager) RevokeAPIKey(keyID string) error {
m.mu.Lock()
defer m.mu.Unlock()
if _, exists := m.keys[keyID]; !exists {
return errors.NotFound("API key not found")
}
delete(m.keys, keyID)
return nil
}
// ListAPIKeys lists all API keys
func (m *Manager) ListAPIKeys() []*APIKey {
m.mu.RLock()
defer m.mu.RUnlock()
keys := make([]*APIKey, 0, len(m.keys))
for _, key := range m.keys {
keys = append(keys, key)
}
return keys
}
// HasPermission checks if an API key has a specific permission
func (k *APIKey) HasPermission(permission Permission) bool {
for _, p := range k.Permissions {
if p == permission {
return true
}
}
return false
}
// getPermissionsForRole returns permissions for a role
func getPermissionsForRole(role Role) []Permission {
switch role {
case RoleReadOnly:
return []Permission{
PermissionReadPackage,
PermissionViewStats,
}
case RoleReadWrite:
return []Permission{
PermissionReadPackage,
PermissionWritePackage,
PermissionViewStats,
}
case RoleAdmin:
return []Permission{
PermissionReadPackage,
PermissionWritePackage,
PermissionDeletePackage,
PermissionViewStats,
PermissionManageKeys,
PermissionManageSettings,
PermissionScanPackages,
PermissionManageBypasses,
}
default:
return []Permission{}
}
}
// generateID generates a unique ID
func generateID() string {
b := make([]byte, 16)
rand.Read(b)
return base64.URLEncoding.EncodeToString(b)
}
+360
View File
@@ -0,0 +1,360 @@
package cdn
import (
"crypto/md5"
"encoding/hex"
"fmt"
"io"
"net/http"
"strconv"
"time"
"github.com/rs/zerolog/log"
)
// CacheControl represents cache control directives
type CacheControl struct {
MaxAge int // max-age in seconds
SMaxAge int // s-maxage in seconds (for shared caches)
Public bool // public directive
Private bool // private directive
NoCache bool // no-cache directive
NoStore bool // no-store directive
MustRevalidate bool // must-revalidate directive
ProxyRevalidate bool // proxy-revalidate directive
Immutable bool // immutable directive
StaleWhileRevalidate int // stale-while-revalidate in seconds
}
// String returns the Cache-Control header value
func (cc CacheControl) String() string {
var parts []string
if cc.Public {
parts = append(parts, "public")
}
if cc.Private {
parts = append(parts, "private")
}
if cc.NoCache {
parts = append(parts, "no-cache")
}
if cc.NoStore {
parts = append(parts, "no-store")
}
if cc.MustRevalidate {
parts = append(parts, "must-revalidate")
}
if cc.ProxyRevalidate {
parts = append(parts, "proxy-revalidate")
}
if cc.Immutable {
parts = append(parts, "immutable")
}
if cc.MaxAge > 0 {
parts = append(parts, fmt.Sprintf("max-age=%d", cc.MaxAge))
}
if cc.SMaxAge > 0 {
parts = append(parts, fmt.Sprintf("s-maxage=%d", cc.SMaxAge))
}
if cc.StaleWhileRevalidate > 0 {
parts = append(parts, fmt.Sprintf("stale-while-revalidate=%d", cc.StaleWhileRevalidate))
}
result := ""
for i, part := range parts {
if i > 0 {
result += ", "
}
result += part
}
return result
}
// Middleware provides CDN and HTTP caching functionality
type Middleware struct {
defaultCacheControl CacheControl
enableETag bool
enableVary bool
}
// Config holds CDN middleware configuration
type Config struct {
DefaultCacheControl CacheControl
EnableETag bool
EnableVary bool
}
// NewMiddleware creates a new CDN middleware
func NewMiddleware(cfg Config) *Middleware {
return &Middleware{
defaultCacheControl: cfg.DefaultCacheControl,
enableETag: cfg.EnableETag,
enableVary: cfg.EnableVary,
}
}
// Handler wraps an HTTP handler with CDN caching support
func (m *Middleware) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Wrap response writer to capture response for ETag generation
rw := &responseWriter{
ResponseWriter: w,
statusCode: http.StatusOK,
body: nil,
}
// Call next handler
next.ServeHTTP(rw, r)
// Apply caching headers if successful response
if rw.statusCode >= 200 && rw.statusCode < 300 {
m.applyCachingHeaders(rw, r)
}
})
}
// applyCachingHeaders applies appropriate caching headers to the response
func (m *Middleware) applyCachingHeaders(w *responseWriter, r *http.Request) {
// Set Cache-Control header if not already set
if w.Header().Get("Cache-Control") == "" {
w.Header().Set("Cache-Control", m.defaultCacheControl.String())
}
// Set Vary header for content negotiation
if m.enableVary {
m.setVaryHeader(w, r)
}
// Generate and check ETag if enabled
if m.enableETag && w.body != nil {
m.handleETag(w, r)
}
}
// setVaryHeader sets the Vary header based on request
func (m *Middleware) setVaryHeader(w *responseWriter, r *http.Request) {
varies := []string{}
// Vary on Accept-Encoding for compression
if r.Header.Get("Accept-Encoding") != "" {
varies = append(varies, "Accept-Encoding")
}
// Vary on Authorization for authenticated requests
if r.Header.Get("Authorization") != "" {
varies = append(varies, "Authorization")
}
// Vary on Accept for content negotiation
if r.Header.Get("Accept") != "" {
varies = append(varies, "Accept")
}
if len(varies) > 0 {
varyHeader := ""
for i, v := range varies {
if i > 0 {
varyHeader += ", "
}
varyHeader += v
}
w.Header().Set("Vary", varyHeader)
}
}
// handleETag generates ETag and handles conditional requests
func (m *Middleware) handleETag(w *responseWriter, r *http.Request) {
// Generate ETag from response body
etag := m.generateETag(w.body)
w.Header().Set("ETag", etag)
// Handle conditional requests
if ifNoneMatch := r.Header.Get("If-None-Match"); ifNoneMatch != "" {
if ifNoneMatch == etag {
// ETag matches - return 304 Not Modified
w.WriteHeader(http.StatusNotModified)
w.body = nil // Clear body for 304 response
log.Debug().
Str("path", r.URL.Path).
Str("etag", etag).
Msg("ETag match - returning 304 Not Modified")
return
}
}
// Handle If-Modified-Since
if lastModified := w.Header().Get("Last-Modified"); lastModified != "" {
if ifModifiedSince := r.Header.Get("If-Modified-Since"); ifModifiedSince != "" {
lastModTime, err := http.ParseTime(lastModified)
if err == nil {
ifModTime, err := http.ParseTime(ifModifiedSince)
if err == nil && !lastModTime.After(ifModTime) {
// Not modified - return 304
w.WriteHeader(http.StatusNotModified)
w.body = nil
log.Debug().
Str("path", r.URL.Path).
Time("last_modified", lastModTime).
Msg("Not modified - returning 304")
return
}
}
}
}
}
// generateETag creates an ETag for HTTP caching
// NOTE: MD5 is used for content fingerprinting (ETag), not cryptographic security
func (m *Middleware) generateETag(body []byte) string {
if body == nil {
return ""
}
hash := md5.Sum(body)
return `"` + hex.EncodeToString(hash[:]) + `"`
}
// SetLastModified sets the Last-Modified header
func SetLastModified(w http.ResponseWriter, t time.Time) {
w.Header().Set("Last-Modified", t.UTC().Format(http.TimeFormat))
}
// SetCacheControl sets a custom Cache-Control header
func SetCacheControl(w http.ResponseWriter, cc CacheControl) {
w.Header().Set("Cache-Control", cc.String())
}
// SetNoCache sets headers to prevent caching
func SetNoCache(w http.ResponseWriter) {
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
w.Header().Set("Pragma", "no-cache")
w.Header().Set("Expires", "0")
}
// SetImmutable sets headers for immutable content (content-addressed files)
func SetImmutable(w http.ResponseWriter, maxAge int) {
cc := CacheControl{
Public: true,
MaxAge: maxAge,
Immutable: true,
}
w.Header().Set("Cache-Control", cc.String())
}
// responseWriter wraps http.ResponseWriter to capture response
type responseWriter struct {
http.ResponseWriter
statusCode int
body []byte
}
func (rw *responseWriter) WriteHeader(statusCode int) {
rw.statusCode = statusCode
rw.ResponseWriter.WriteHeader(statusCode)
}
func (rw *responseWriter) Write(b []byte) (int, error) {
// Capture body for ETag generation
if rw.body == nil {
rw.body = make([]byte, 0, len(b))
}
rw.body = append(rw.body, b...)
return rw.ResponseWriter.Write(b)
}
// HandleRange handles HTTP Range requests for partial content
func HandleRange(w http.ResponseWriter, r *http.Request, content io.ReadSeeker, size int64, modTime time.Time) error {
// Set Last-Modified header
SetLastModified(w, modTime)
// Check for Range header
rangeHeader := r.Header.Get("Range")
if rangeHeader == "" {
// No range request - serve full content
w.Header().Set("Content-Length", strconv.FormatInt(size, 10))
w.Header().Set("Accept-Ranges", "bytes")
w.WriteHeader(http.StatusOK)
_, err := io.Copy(w, content)
return err
}
// Parse range header (simplified - only handles single range)
// Format: bytes=start-end
var start, end int64
n, err := fmt.Sscanf(rangeHeader, "bytes=%d-%d", &start, &end)
if err != nil || n != 2 {
// Invalid range - serve full content
w.Header().Set("Content-Length", strconv.FormatInt(size, 10))
w.Header().Set("Accept-Ranges", "bytes")
w.WriteHeader(http.StatusOK)
_, err := io.Copy(w, content)
return err
}
// Validate range
if start < 0 || start >= size || end < start || end >= size {
w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", size))
w.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return nil
}
// Seek to start position
if _, err := content.Seek(start, io.SeekStart); err != nil {
return err
}
// Calculate content length
contentLength := end - start + 1
// Set headers for partial content
w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, size))
w.Header().Set("Content-Length", strconv.FormatInt(contentLength, 10))
w.Header().Set("Accept-Ranges", "bytes")
w.WriteHeader(http.StatusPartialContent)
// Copy range to response
_, err = io.CopyN(w, content, contentLength)
return err
}
// DefaultCacheControl returns sensible defaults for different content types
func DefaultCacheControl(contentType string, versioned bool) CacheControl {
if versioned {
// Content-addressed or versioned resources can be cached forever
return CacheControl{
Public: true,
MaxAge: 31536000, // 1 year
Immutable: true,
}
}
// Default caching based on content type
switch contentType {
case "application/json":
return CacheControl{
Public: true,
MaxAge: 3600, // 1 hour
SMaxAge: 7200, // 2 hours for shared caches
}
case "application/octet-stream", "application/x-gzip", "application/zip":
// Binary packages
return CacheControl{
Public: true,
MaxAge: 86400, // 1 day
SMaxAge: 604800, // 1 week for shared caches
}
case "text/html":
// HTML should revalidate
return CacheControl{
Public: true,
MaxAge: 0,
MustRevalidate: true,
}
default:
return CacheControl{
Public: true,
MaxAge: 3600, // 1 hour default
SMaxAge: 7200,
}
}
}
+395
View File
@@ -0,0 +1,395 @@
package config
import (
"fmt"
"time"
)
// Config is the main configuration struct
type Config struct {
Server ServerConfig `mapstructure:"server" json:"server"`
Storage StorageConfig `mapstructure:"storage" json:"storage"`
Metadata MetadataConfig `mapstructure:"metadata" json:"metadata"`
Cache CacheConfig `mapstructure:"cache" json:"cache"`
Security SecurityConfig `mapstructure:"security" json:"security"`
Auth AuthConfig `mapstructure:"auth" json:"auth"`
Network NetworkConfig `mapstructure:"network" json:"network"`
Logging LoggingConfig `mapstructure:"logging" json:"logging"`
Handlers HandlersConfig `mapstructure:"handlers" json:"handlers"`
}
// ServerConfig contains HTTP server configuration
type ServerConfig struct {
Host string `mapstructure:"host" json:"host"`
Port int `mapstructure:"port" json:"port"`
ReadTimeout time.Duration `mapstructure:"read_timeout" json:"read_timeout"`
WriteTimeout time.Duration `mapstructure:"write_timeout" json:"write_timeout"`
IdleTimeout time.Duration `mapstructure:"idle_timeout" json:"idle_timeout"`
TLS TLSConfig `mapstructure:"tls" json:"tls"`
}
// TLSConfig contains TLS/HTTPS configuration
type TLSConfig struct {
Enabled bool `mapstructure:"enabled" json:"enabled"`
CertFile string `mapstructure:"cert_file" json:"cert_file"`
KeyFile string `mapstructure:"key_file" json:"key_file"`
}
// StorageConfig contains storage backend configuration
type StorageConfig struct {
Backend string `mapstructure:"backend" json:"backend"` // filesystem, s3, smb, nfs
Path string `mapstructure:"path" json:"path"`
Filesystem FilesystemConfig `mapstructure:"filesystem" json:"filesystem"`
S3 S3Config `mapstructure:"s3" json:"s3"`
SMB SMBConfig `mapstructure:"smb" json:"smb"`
Options map[string]interface{} `mapstructure:"options" json:"options"`
}
// FilesystemConfig contains local filesystem storage configuration
type FilesystemConfig struct {
BasePath string `mapstructure:"base_path" json:"base_path"`
}
// S3Config contains S3-compatible storage configuration
type S3Config struct {
Endpoint string `mapstructure:"endpoint" json:"endpoint"`
Region string `mapstructure:"region" json:"region"`
Bucket string `mapstructure:"bucket" json:"bucket"`
AccessKeyID string `mapstructure:"access_key_id" json:"access_key_id"`
SecretAccessKey string `mapstructure:"secret_access_key" json:"-"` // Don't serialize secrets
UseSSL bool `mapstructure:"use_ssl" json:"use_ssl"`
}
// SMBConfig contains SMB/CIFS storage configuration
type SMBConfig struct {
Host string `mapstructure:"host" json:"host"`
Share string `mapstructure:"share" json:"share"`
Username string `mapstructure:"username" json:"username"`
Password string `mapstructure:"password" json:"-"` // Don't serialize secrets
Domain string `mapstructure:"domain" json:"domain"`
}
// MetadataConfig contains metadata store configuration
type MetadataConfig struct {
Backend string `mapstructure:"backend" json:"backend"` // sqlite, postgresql, file
Connection string `mapstructure:"connection" json:"connection"`
SQLite SQLiteConfig `mapstructure:"sqlite" json:"sqlite"`
PostgreSQL PostgreSQLConfig `mapstructure:"postgresql" json:"postgresql"`
}
// SQLiteConfig contains SQLite-specific configuration
type SQLiteConfig struct {
Path string `mapstructure:"path" json:"path"`
WALMode bool `mapstructure:"wal_mode" json:"wal_mode"`
}
// PostgreSQLConfig contains PostgreSQL-specific configuration
type PostgreSQLConfig struct {
Host string `mapstructure:"host" json:"host"`
Port int `mapstructure:"port" json:"port"`
Database string `mapstructure:"database" json:"database"`
User string `mapstructure:"user" json:"user"`
Password string `mapstructure:"password" json:"-"` // Don't serialize secrets
SSLMode string `mapstructure:"ssl_mode" json:"ssl_mode"`
}
// CacheConfig contains cache management configuration
type CacheConfig struct {
DefaultTTL time.Duration `mapstructure:"default_ttl" json:"default_ttl"`
CleanupInterval time.Duration `mapstructure:"cleanup_interval" json:"cleanup_interval"`
MaxSizeBytes int64 `mapstructure:"max_size_bytes" json:"max_size_bytes"`
PerProjectQuota int64 `mapstructure:"per_project_quota" json:"per_project_quota"`
TTLOverrides map[string]time.Duration `mapstructure:"ttl_overrides" json:"ttl_overrides"` // Per ecosystem
}
// SecurityConfig contains security scanning configuration
type SecurityConfig struct {
Enabled bool `mapstructure:"enabled" json:"enabled"`
ScanOnDownload bool `mapstructure:"scan_on_download" json:"scan_on_download"` // Scan packages on first download
RescanInterval time.Duration `mapstructure:"rescan_interval" json:"rescan_interval"` // How often to re-scan (e.g., 24h, 168h for weekly)
BlockOnSeverity string `mapstructure:"block_on_severity" json:"block_on_severity"` // none, low, medium, high, critical
BlockThresholds VulnerabilityThresholds `mapstructure:"block_thresholds" json:"block_thresholds"` // Max vulns per severity before blocking
UpdateDBOnStartup bool `mapstructure:"update_db_on_startup" json:"update_db_on_startup"` // Update vulnerability databases on startup
AllowedPackages []string `mapstructure:"allowed_packages" json:"allowed_packages"` // Packages that bypass security checks (format: "registry/name@version" or "registry/name")
IgnoredCVEs []string `mapstructure:"ignored_cves" json:"ignored_cves"` // CVE IDs to ignore globally (e.g., "CVE-2021-23337")
Scanners ScannersConfig `mapstructure:"scanners" json:"scanners"`
}
// VulnerabilityThresholds defines max allowed vulnerabilities per severity
type VulnerabilityThresholds struct {
Critical int `mapstructure:"critical" json:"critical"` // Max critical vulns (0 = block any)
High int `mapstructure:"high" json:"high"` // Max high vulns
Medium int `mapstructure:"medium" json:"medium"` // Max medium vulns
Low int `mapstructure:"low" json:"low"` // Max low vulns (-1 = unlimited)
}
// ScannersConfig contains individual scanner configurations
type ScannersConfig struct {
Trivy TrivyConfig `mapstructure:"trivy" json:"trivy"`
OSV OSVConfig `mapstructure:"osv" json:"osv"`
Static StaticConfig `mapstructure:"static" json:"static"`
}
// TrivyConfig contains Trivy scanner configuration
type TrivyConfig struct {
Enabled bool `mapstructure:"enabled" json:"enabled"`
Timeout time.Duration `mapstructure:"timeout" json:"timeout"`
CacheDB string `mapstructure:"cache_db" json:"cache_db"`
}
// OSVConfig contains OSV scanner configuration
type OSVConfig struct {
Enabled bool `mapstructure:"enabled" json:"enabled"`
APIURL string `mapstructure:"api_url" json:"api_url"`
Timeout time.Duration `mapstructure:"timeout" json:"timeout"`
}
// StaticConfig contains static analysis configuration
type StaticConfig struct {
Enabled bool `mapstructure:"enabled" json:"enabled"`
MaxPackageSize int64 `mapstructure:"max_package_size" json:"max_package_size"`
CheckChecksums bool `mapstructure:"check_checksums" json:"check_checksums"`
BlockSuspicious bool `mapstructure:"block_suspicious" json:"block_suspicious"`
AllowedLicenses []string `mapstructure:"allowed_licenses" json:"allowed_licenses"`
}
// AuthConfig contains authentication configuration
type AuthConfig struct {
Enabled bool `mapstructure:"enabled" json:"enabled"`
KeyExpiration time.Duration `mapstructure:"key_expiration" json:"key_expiration"`
BcryptCost int `mapstructure:"bcrypt_cost" json:"bcrypt_cost"`
AuditLog bool `mapstructure:"audit_log" json:"audit_log"`
}
// NetworkConfig contains network resilience configuration
type NetworkConfig struct {
ConnectTimeout time.Duration `mapstructure:"connect_timeout" json:"connect_timeout"`
ReadTimeout time.Duration `mapstructure:"read_timeout" json:"read_timeout"`
WriteTimeout time.Duration `mapstructure:"write_timeout" json:"write_timeout"`
MaxIdleConns int `mapstructure:"max_idle_conns" json:"max_idle_conns"`
MaxConnsPerHost int `mapstructure:"max_conns_per_host" json:"max_conns_per_host"`
RateLimit RateLimitConfig `mapstructure:"rate_limit" json:"rate_limit"`
CircuitBreaker CircuitBreakerConfig `mapstructure:"circuit_breaker" json:"circuit_breaker"`
Retry RetryConfig `mapstructure:"retry" json:"retry"`
}
// RateLimitConfig contains rate limiting configuration
type RateLimitConfig struct {
PerAPIKey int `mapstructure:"per_api_key" json:"per_api_key"`
PerIP int `mapstructure:"per_ip" json:"per_ip"`
BurstSize int `mapstructure:"burst_size" json:"burst_size"`
}
// CircuitBreakerConfig contains circuit breaker configuration
type CircuitBreakerConfig struct {
Threshold int `mapstructure:"threshold" json:"threshold"`
Timeout time.Duration `mapstructure:"timeout" json:"timeout"`
ResetInterval time.Duration `mapstructure:"reset_interval" json:"reset_interval"`
}
// RetryConfig contains retry policy configuration
type RetryConfig struct {
MaxAttempts int `mapstructure:"max_attempts" json:"max_attempts"`
InitialBackoff time.Duration `mapstructure:"initial_backoff" json:"initial_backoff"`
MaxBackoff time.Duration `mapstructure:"max_backoff" json:"max_backoff"`
}
// LoggingConfig contains logging configuration
type LoggingConfig struct {
Level string `mapstructure:"level" json:"level"` // debug, info, warn, error
Format string `mapstructure:"format" json:"format"` // json, pretty
}
// HandlersConfig contains package manager handler configurations
type HandlersConfig struct {
Go GoHandlerConfig `mapstructure:"go" json:"go"`
NPM NPMHandlerConfig `mapstructure:"npm" json:"npm"`
PyPI PyPIHandlerConfig `mapstructure:"pypi" json:"pypi"`
}
// GoHandlerConfig contains Go proxy configuration
type GoHandlerConfig struct {
Enabled bool `mapstructure:"enabled" json:"enabled"`
UpstreamProxy string `mapstructure:"upstream_proxy" json:"upstream_proxy"`
ChecksumDB string `mapstructure:"checksum_db" json:"checksum_db"`
VerifyChecksums bool `mapstructure:"verify_checksums" json:"verify_checksums"`
}
// NPMHandlerConfig contains NPM registry configuration
type NPMHandlerConfig struct {
Enabled bool `mapstructure:"enabled" json:"enabled"`
UpstreamRegistry string `mapstructure:"upstream_registry" json:"upstream_registry"`
}
// PyPIHandlerConfig contains PyPI configuration
type PyPIHandlerConfig struct {
Enabled bool `mapstructure:"enabled" json:"enabled"`
UpstreamURL string `mapstructure:"upstream_url" json:"upstream_url"`
SimpleAPIURL string `mapstructure:"simple_api_url" json:"simple_api_url"`
}
// Default returns a configuration with sensible defaults
func Default() *Config {
return &Config{
Server: ServerConfig{
Host: "0.0.0.0",
Port: 8080,
ReadTimeout: 5 * time.Minute,
WriteTimeout: 5 * time.Minute,
IdleTimeout: 2 * time.Minute,
TLS: TLSConfig{
Enabled: false,
},
},
Storage: StorageConfig{
Backend: "filesystem",
Path: "/var/cache/gohoarder",
Filesystem: FilesystemConfig{
BasePath: "/var/cache/gohoarder",
},
},
Metadata: MetadataConfig{
Backend: "sqlite",
Connection: "file:gohoarder.db?cache=shared&mode=rwc",
SQLite: SQLiteConfig{
Path: "gohoarder.db",
WALMode: true,
},
},
Cache: CacheConfig{
DefaultTTL: 7 * 24 * time.Hour,
CleanupInterval: 1 * time.Hour,
MaxSizeBytes: 500 * 1024 * 1024 * 1024, // 500GB
PerProjectQuota: 50 * 1024 * 1024 * 1024, // 50GB
TTLOverrides: map[string]time.Duration{
"npm": 7 * 24 * time.Hour,
"pip": 7 * 24 * time.Hour,
"go": 7 * 24 * time.Hour,
},
},
Security: SecurityConfig{
Enabled: false,
BlockOnSeverity: "high",
Scanners: ScannersConfig{
Trivy: TrivyConfig{
Enabled: false,
Timeout: 5 * time.Minute,
CacheDB: "/var/lib/trivy",
},
OSV: OSVConfig{
Enabled: false,
APIURL: "https://api.osv.dev",
Timeout: 30 * time.Second,
},
Static: StaticConfig{
Enabled: true,
MaxPackageSize: 2 * 1024 * 1024 * 1024, // 2GB
CheckChecksums: true,
BlockSuspicious: false,
},
},
},
Auth: AuthConfig{
Enabled: true,
KeyExpiration: 0, // Never expire
BcryptCost: 10,
AuditLog: true,
},
Network: NetworkConfig{
ConnectTimeout: 10 * time.Second,
ReadTimeout: 5 * time.Minute,
WriteTimeout: 5 * time.Minute,
MaxIdleConns: 100,
MaxConnsPerHost: 10,
RateLimit: RateLimitConfig{
PerAPIKey: 1000,
PerIP: 100,
BurstSize: 50,
},
CircuitBreaker: CircuitBreakerConfig{
Threshold: 5,
Timeout: 30 * time.Second,
ResetInterval: 60 * time.Second,
},
Retry: RetryConfig{
MaxAttempts: 3,
InitialBackoff: 1 * time.Second,
MaxBackoff: 30 * time.Second,
},
},
Logging: LoggingConfig{
Level: "info",
Format: "json",
},
Handlers: HandlersConfig{
Go: GoHandlerConfig{
Enabled: true,
UpstreamProxy: "https://proxy.golang.org",
ChecksumDB: "https://sum.golang.org",
VerifyChecksums: true,
},
NPM: NPMHandlerConfig{
Enabled: true,
UpstreamRegistry: "https://registry.npmjs.org",
},
PyPI: PyPIHandlerConfig{
Enabled: true,
UpstreamURL: "https://pypi.org",
SimpleAPIURL: "https://pypi.org/simple",
},
},
}
}
// Validate validates the configuration
func (c *Config) Validate() error {
// Validate server
if c.Server.Port < 1 || c.Server.Port > 65535 {
return fmt.Errorf("server.port must be between 1 and 65535, got %d", c.Server.Port)
}
// Validate storage backend
validStorageBackends := map[string]bool{"filesystem": true, "s3": true, "smb": true, "nfs": true}
if !validStorageBackends[c.Storage.Backend] {
return fmt.Errorf("storage.backend must be one of: filesystem, s3, smb, nfs; got %s", c.Storage.Backend)
}
// Validate metadata backend
validMetadataBackends := map[string]bool{"sqlite": true, "postgresql": true, "file": true}
if !validMetadataBackends[c.Metadata.Backend] {
return fmt.Errorf("metadata.backend must be one of: sqlite, postgresql, file; got %s", c.Metadata.Backend)
}
// Validate cache
if c.Cache.DefaultTTL < 0 {
return fmt.Errorf("cache.default_ttl cannot be negative")
}
if c.Cache.MaxSizeBytes < 0 {
return fmt.Errorf("cache.max_size_bytes cannot be negative")
}
// Validate security
validSeverities := map[string]bool{"none": true, "low": true, "medium": true, "high": true, "critical": true}
if !validSeverities[c.Security.BlockOnSeverity] {
return fmt.Errorf("security.block_on_severity must be one of: none, low, medium, high, critical; got %s", c.Security.BlockOnSeverity)
}
// Validate logging level
validLevels := map[string]bool{"debug": true, "info": true, "warn": true, "error": true}
if !validLevels[c.Logging.Level] {
return fmt.Errorf("logging.level must be one of: debug, info, warn, error; got %s", c.Logging.Level)
}
// Validate logging format
validFormats := map[string]bool{"json": true, "pretty": true}
if !validFormats[c.Logging.Format] {
return fmt.Errorf("logging.format must be one of: json, pretty; got %s", c.Logging.Format)
}
// Validate auth
if c.Auth.BcryptCost < 4 || c.Auth.BcryptCost > 31 {
return fmt.Errorf("auth.bcrypt_cost must be between 4 and 31, got %d", c.Auth.BcryptCost)
}
return nil
}
+383
View File
@@ -0,0 +1,383 @@
package config
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
)
type ConfigTestSuite struct {
suite.Suite
tempDir string
}
func TestConfigTestSuite(t *testing.T) {
suite.Run(t, new(ConfigTestSuite))
}
func (s *ConfigTestSuite) SetupTest() {
var err error
s.tempDir, err = os.MkdirTemp("", "gohoarder-config-test-*")
s.Require().NoError(err)
}
func (s *ConfigTestSuite) TearDownTest() {
os.RemoveAll(s.tempDir)
}
func (s *ConfigTestSuite) TestDefault() {
cfg := Default()
s.NotNil(cfg)
s.Equal("0.0.0.0", cfg.Server.Host)
s.Equal(8080, cfg.Server.Port)
s.Equal("filesystem", cfg.Storage.Backend)
s.Equal("sqlite", cfg.Metadata.Backend)
s.NoError(cfg.Validate())
}
func (s *ConfigTestSuite) TestValidate() {
tests := []struct {
name string
modify func(*Config)
expectError bool
errorSubstr string
}{
{
name: "valid_config",
modify: func(c *Config) {},
expectError: false,
},
{
name: "invalid_port_too_low",
modify: func(c *Config) {
c.Server.Port = 0
},
expectError: true,
errorSubstr: "port must be between",
},
{
name: "invalid_port_too_high",
modify: func(c *Config) {
c.Server.Port = 70000
},
expectError: true,
errorSubstr: "port must be between",
},
{
name: "invalid_storage_backend",
modify: func(c *Config) {
c.Storage.Backend = "invalid"
},
expectError: true,
errorSubstr: "storage.backend must be one of",
},
{
name: "invalid_metadata_backend",
modify: func(c *Config) {
c.Metadata.Backend = "mongodb"
},
expectError: true,
errorSubstr: "metadata.backend must be one of",
},
{
name: "negative_ttl",
modify: func(c *Config) {
c.Cache.DefaultTTL = -1 * time.Hour
},
expectError: true,
errorSubstr: "cannot be negative",
},
{
name: "negative_cache_size",
modify: func(c *Config) {
c.Cache.MaxSizeBytes = -100
},
expectError: true,
errorSubstr: "cannot be negative",
},
{
name: "invalid_severity",
modify: func(c *Config) {
c.Security.BlockOnSeverity = "super-high"
},
expectError: true,
errorSubstr: "block_on_severity must be one of",
},
{
name: "invalid_log_level",
modify: func(c *Config) {
c.Logging.Level = "verbose"
},
expectError: true,
errorSubstr: "logging.level must be one of",
},
{
name: "invalid_log_format",
modify: func(c *Config) {
c.Logging.Format = "xml"
},
expectError: true,
errorSubstr: "logging.format must be one of",
},
{
name: "invalid_bcrypt_cost_too_low",
modify: func(c *Config) {
c.Auth.BcryptCost = 3
},
expectError: true,
errorSubstr: "bcrypt_cost must be between",
},
{
name: "invalid_bcrypt_cost_too_high",
modify: func(c *Config) {
c.Auth.BcryptCost = 32
},
expectError: true,
errorSubstr: "bcrypt_cost must be between",
},
{
name: "valid_s3_backend",
modify: func(c *Config) {
c.Storage.Backend = "s3"
},
expectError: false,
},
{
name: "valid_postgresql_backend",
modify: func(c *Config) {
c.Metadata.Backend = "postgresql"
},
expectError: false,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
cfg := Default()
tt.modify(cfg)
err := cfg.Validate()
if tt.expectError {
s.Error(err)
if tt.errorSubstr != "" {
s.Contains(err.Error(), tt.errorSubstr)
}
} else {
s.NoError(err)
}
})
}
}
func (s *ConfigTestSuite) TestLoad() {
tests := []struct {
name string
configYAML string
envVars map[string]string
expectError bool
validate func(*Config)
}{
{
name: "valid_yaml_config",
configYAML: `
server:
host: 127.0.0.1
port: 9000
storage:
backend: filesystem
path: /custom/path
logging:
level: debug
format: pretty
`,
expectError: false,
validate: func(cfg *Config) {
s.Equal("127.0.0.1", cfg.Server.Host)
s.Equal(9000, cfg.Server.Port)
s.Equal("/custom/path", cfg.Storage.Path)
s.Equal("debug", cfg.Logging.Level)
s.Equal("pretty", cfg.Logging.Format)
},
},
{
name: "env_var_override",
configYAML: `
server:
port: 8080
`,
envVars: map[string]string{
"GOHOARDER_SERVER_PORT": "9090",
},
expectError: false,
validate: func(cfg *Config) {
s.Equal(9090, cfg.Server.Port)
},
},
{
name: "invalid_yaml",
configYAML: `
server: [invalid
`,
expectError: true,
},
{
name: "validation_failure",
configYAML: `
server:
port: 100000
`,
expectError: true,
},
{
name: "complete_config",
configYAML: `
server:
host: 0.0.0.0
port: 8080
read_timeout: 300s
write_timeout: 300s
storage:
backend: s3
s3:
endpoint: s3.amazonaws.com
region: us-east-1
bucket: my-cache
access_key_id: AKIAIOSFODNN7EXAMPLE
secret_access_key: wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY
metadata:
backend: postgresql
postgresql:
host: localhost
port: 5432
database: gohoarder
user: postgres
password: secret
ssl_mode: require
cache:
default_ttl: 168h
max_size_bytes: 536870912000
security:
enabled: true
block_on_severity: high
scanners:
trivy:
enabled: true
timeout: 300s
auth:
enabled: true
bcrypt_cost: 12
`,
expectError: false,
validate: func(cfg *Config) {
s.Equal("s3", cfg.Storage.Backend)
s.Equal("s3.amazonaws.com", cfg.Storage.S3.Endpoint)
s.Equal("postgresql", cfg.Metadata.Backend)
s.Equal("localhost", cfg.Metadata.PostgreSQL.Host)
s.True(cfg.Security.Enabled)
s.Equal(12, cfg.Auth.BcryptCost)
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
// Write config file
configPath := filepath.Join(s.tempDir, "config.yaml")
err := os.WriteFile(configPath, []byte(tt.configYAML), 0644)
s.Require().NoError(err)
// Set environment variables
for k, v := range tt.envVars {
os.Setenv(k, v)
defer os.Unsetenv(k)
}
// Load config
cfg, err := Load(configPath)
if tt.expectError {
s.Error(err)
} else {
s.NoError(err)
s.NotNil(cfg)
if tt.validate != nil {
tt.validate(cfg)
}
}
})
}
}
func (s *ConfigTestSuite) TestLoadMissingFile() {
// Should return error when file explicitly specified but not found
cfg, err := Load("/nonexistent/path/to/config.yaml")
s.Error(err)
s.Nil(cfg)
}
func (s *ConfigTestSuite) TestLoadWithDefaults() {
// Invalid config path should return defaults
cfg := LoadWithDefaults("/invalid/path/config.yaml")
s.NotNil(cfg)
s.Equal(8080, cfg.Server.Port)
}
// Benchmark tests
func BenchmarkDefault(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = Default()
}
}
func BenchmarkValidate(b *testing.B) {
cfg := Default()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = cfg.Validate()
}
}
// Table-driven edge cases
func TestConfigEdgeCases(t *testing.T) {
tests := []struct {
name string
config *Config
valid bool
}{
{
name: "minimal_config",
config: &Config{Server: ServerConfig{Port: 8080}, Storage: StorageConfig{Backend: "filesystem"}, Metadata: MetadataConfig{Backend: "sqlite"}, Logging: LoggingConfig{Level: "info", Format: "json"}, Security: SecurityConfig{BlockOnSeverity: "high"}, Auth: AuthConfig{BcryptCost: 10}},
valid: true,
},
{
name: "zero_ttl",
config: func() *Config { c := Default(); c.Cache.DefaultTTL = 0; return c }(),
valid: true, // Zero is valid (no caching)
},
{
name: "max_bcrypt_cost",
config: func() *Config { c := Default(); c.Auth.BcryptCost = 31; return c }(),
valid: true,
},
{
name: "min_bcrypt_cost",
config: func() *Config { c := Default(); c.Auth.BcryptCost = 4; return c }(),
valid: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.Validate()
if tt.valid {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
})
}
}
+62
View File
@@ -0,0 +1,62 @@
package config
import (
"fmt"
"strings"
"github.com/spf13/viper"
)
// Load loads configuration from file and environment variables
func Load(configPath string) (*Config, error) {
v := viper.New()
// Set config file if provided
if configPath != "" {
v.SetConfigFile(configPath)
} else {
// Look for config.yaml in current directory and /etc/gohoarder
v.SetConfigName("config")
v.SetConfigType("yaml")
v.AddConfigPath(".")
v.AddConfigPath("/etc/gohoarder")
v.AddConfigPath("$HOME/.gohoarder")
}
// Set environment variable prefix
v.SetEnvPrefix("GOHOARDER")
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
v.AutomaticEnv()
// Read config file
if err := v.ReadInConfig(); err != nil {
// If no config file found, use defaults
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
return nil, fmt.Errorf("failed to read config file: %w", err)
}
}
// Start with defaults
cfg := Default()
// Unmarshal into config struct
if err := v.Unmarshal(cfg); err != nil {
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
}
// Validate configuration
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("config validation failed: %w", err)
}
return cfg, nil
}
// LoadWithDefaults loads configuration or returns defaults on error
func LoadWithDefaults(configPath string) *Config {
cfg, err := Load(configPath)
if err != nil {
return Default()
}
return cfg
}
+68
View File
@@ -0,0 +1,68 @@
package errors
// Error codes following consistent naming convention
const (
// Client errors (4xx)
ErrCodeBadRequest = "BAD_REQUEST"
ErrCodeUnauthorized = "UNAUTHORIZED"
ErrCodeForbidden = "FORBIDDEN"
ErrCodeNotFound = "NOT_FOUND"
ErrCodeRateLimited = "RATE_LIMITED"
ErrCodePayloadTooLarge = "PAYLOAD_TOO_LARGE"
ErrCodeInvalidAPIKey = "INVALID_API_KEY"
ErrCodeQuotaExceeded = "QUOTA_EXCEEDED"
ErrCodeConflict = "CONFLICT"
ErrCodeInvalidConfig = "INVALID_CONFIG"
// Package-specific errors
ErrCodePackageNotFound = "PACKAGE_NOT_FOUND"
ErrCodeVersionNotFound = "VERSION_NOT_FOUND"
ErrCodeChecksumMismatch = "CHECKSUM_MISMATCH"
ErrCodeCorruptPackage = "CORRUPT_PACKAGE"
ErrCodeSecurityBlocked = "SECURITY_BLOCKED"
ErrCodeSecurityViolation = "SECURITY_VIOLATION" // Package has vulnerabilities exceeding thresholds
ErrCodeUpstreamError = "UPSTREAM_ERROR"
// Server errors (5xx)
ErrCodeInternalServer = "INTERNAL_SERVER_ERROR"
ErrCodeStorageFailure = "STORAGE_FAILURE"
ErrCodeUpstreamFailure = "UPSTREAM_FAILURE"
ErrCodeDatabaseFailure = "DATABASE_FAILURE"
ErrCodeServiceUnavailable = "SERVICE_UNAVAILABLE"
ErrCodeCircuitOpen = "CIRCUIT_OPEN"
)
// HTTPStatusCode maps error codes to HTTP status codes
var HTTPStatusCode = map[string]int{
ErrCodeBadRequest: 400,
ErrCodeUnauthorized: 401,
ErrCodeForbidden: 403,
ErrCodeNotFound: 404,
ErrCodeConflict: 409,
ErrCodeRateLimited: 429,
ErrCodePayloadTooLarge: 413,
ErrCodeInvalidAPIKey: 401,
ErrCodeQuotaExceeded: 429,
ErrCodeInvalidConfig: 400,
ErrCodePackageNotFound: 404,
ErrCodeVersionNotFound: 404,
ErrCodeChecksumMismatch: 422,
ErrCodeCorruptPackage: 422,
ErrCodeSecurityBlocked: 403,
ErrCodeSecurityViolation: 426, // Upgrade Required
ErrCodeUpstreamError: 502,
ErrCodeInternalServer: 500,
ErrCodeStorageFailure: 500,
ErrCodeUpstreamFailure: 502,
ErrCodeDatabaseFailure: 500,
ErrCodeServiceUnavailable: 503,
ErrCodeCircuitOpen: 503,
}
// GetHTTPStatus returns the HTTP status code for an error code
func GetHTTPStatus(code string) int {
if status, ok := HTTPStatusCode[code]; ok {
return status
}
return 500 // Default to internal server error
}
+115
View File
@@ -0,0 +1,115 @@
package errors
import (
"fmt"
)
// Error represents a structured error with code and details
type Error struct {
Code string `json:"code"`
Message string `json:"message"`
Details interface{} `json:"details,omitempty"`
Trace []string `json:"trace,omitempty"`
Cause error `json:"-"` // Internal cause, not serialized
}
// Error implements the error interface
func (e *Error) Error() string {
if e.Cause != nil {
return fmt.Sprintf("%s: %s (caused by: %v)", e.Code, e.Message, e.Cause)
}
return fmt.Sprintf("%s: %s", e.Code, e.Message)
}
// Unwrap returns the cause for errors.Is/As support
func (e *Error) Unwrap() error {
return e.Cause
}
// New creates a new error with the given code and message
func New(code, message string) *Error {
return &Error{
Code: code,
Message: message,
}
}
// Newf creates a new error with formatted message
func Newf(code, format string, args ...interface{}) *Error {
return &Error{
Code: code,
Message: fmt.Sprintf(format, args...),
}
}
// WithDetails adds details to the error
func (e *Error) WithDetails(details interface{}) *Error {
e.Details = details
return e
}
// WithTrace adds stack trace to the error
func (e *Error) WithTrace(trace []string) *Error {
e.Trace = trace
return e
}
// WithCause adds an underlying cause to the error
func (e *Error) WithCause(cause error) *Error {
e.Cause = cause
return e
}
// Wrap wraps an existing error with a new code and message
func Wrap(err error, code, message string) *Error {
return &Error{
Code: code,
Message: message,
Cause: err,
}
}
// Wrapf wraps an existing error with formatted message
func Wrapf(err error, code, format string, args ...interface{}) *Error {
return &Error{
Code: code,
Message: fmt.Sprintf(format, args...),
Cause: err,
}
}
// Common error constructors
func BadRequest(message string) *Error {
return New(ErrCodeBadRequest, message)
}
func Unauthorized(message string) *Error {
return New(ErrCodeUnauthorized, message)
}
func Forbidden(message string) *Error {
return New(ErrCodeForbidden, message)
}
func NotFound(message string) *Error {
return New(ErrCodeNotFound, message)
}
func InternalServer(message string) *Error {
return New(ErrCodeInternalServer, message)
}
func PackageNotFound(name, version string) *Error {
return New(ErrCodePackageNotFound, fmt.Sprintf("Package %s@%s not found", name, version)).
WithDetails(map[string]string{
"package": name,
"version": version,
})
}
func QuotaExceeded(limit int64) *Error {
return New(ErrCodeQuotaExceeded, "Storage quota exceeded").
WithDetails(map[string]interface{}{
"limit_bytes": limit,
})
}
+305
View File
@@ -0,0 +1,305 @@
package errors
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
)
type ErrorsTestSuite struct {
suite.Suite
}
func TestErrorsTestSuite(t *testing.T) {
suite.Run(t, new(ErrorsTestSuite))
}
func (s *ErrorsTestSuite) TestNew() {
tests := []struct {
name string
code string
message string
}{
{
name: "simple_error",
code: ErrCodeNotFound,
message: "Resource not found",
},
{
name: "empty_message",
code: ErrCodeBadRequest,
message: "",
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
err := New(tt.code, tt.message)
s.Equal(tt.code, err.Code)
s.Equal(tt.message, err.Message)
s.Nil(err.Details)
s.Nil(err.Trace)
s.Nil(err.Cause)
})
}
}
func (s *ErrorsTestSuite) TestNewf() {
tests := []struct {
name string
code string
format string
args []interface{}
expected string
}{
{
name: "formatted_message",
code: ErrCodePackageNotFound,
format: "Package %s@%s not found",
args: []interface{}{"react", "18.2.0"},
expected: "Package react@18.2.0 not found",
},
{
name: "no_args",
code: ErrCodeInternalServer,
format: "Internal error",
args: []interface{}{},
expected: "Internal error",
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
err := Newf(tt.code, tt.format, tt.args...)
s.Equal(tt.code, err.Code)
s.Equal(tt.expected, err.Message)
})
}
}
func (s *ErrorsTestSuite) TestWithDetails() {
tests := []struct {
name string
details interface{}
}{
{
name: "map_details",
details: map[string]string{"key": "value"},
},
{
name: "string_details",
details: "some details",
},
{
name: "nil_details",
details: nil,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
err := New(ErrCodeBadRequest, "test").WithDetails(tt.details)
s.Equal(tt.details, err.Details)
})
}
}
func (s *ErrorsTestSuite) TestWithTrace() {
trace := []string{"file1.go:10", "file2.go:20"}
err := New(ErrCodeInternalServer, "test").WithTrace(trace)
s.Equal(trace, err.Trace)
}
func (s *ErrorsTestSuite) TestWithCause() {
cause := errors.New("underlying error")
err := New(ErrCodeStorageFailure, "test").WithCause(cause)
s.Equal(cause, err.Cause)
s.Contains(err.Error(), "underlying error")
}
func (s *ErrorsTestSuite) TestWrap() {
cause := errors.New("original error")
wrapped := Wrap(cause, ErrCodeDatabaseFailure, "database connection failed")
s.Equal(ErrCodeDatabaseFailure, wrapped.Code)
s.Equal("database connection failed", wrapped.Message)
s.Equal(cause, wrapped.Cause)
s.True(errors.Is(wrapped, cause))
}
func (s *ErrorsTestSuite) TestWrapf() {
cause := errors.New("connection refused")
wrapped := Wrapf(cause, ErrCodeUpstreamFailure, "failed to connect to %s", "registry.npmjs.org")
s.Equal(ErrCodeUpstreamFailure, wrapped.Code)
s.Equal("failed to connect to registry.npmjs.org", wrapped.Message)
s.Equal(cause, wrapped.Cause)
}
func (s *ErrorsTestSuite) TestErrorString() {
tests := []struct {
name string
err *Error
expected string
}{
{
name: "error_without_cause",
err: New(ErrCodeNotFound, "not found"),
expected: "NOT_FOUND: not found",
},
{
name: "error_with_cause",
err: Wrap(errors.New("io error"), ErrCodeStorageFailure, "storage failed"),
expected: "STORAGE_FAILURE: storage failed (caused by: io error)",
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
s.Equal(tt.expected, tt.err.Error())
})
}
}
func (s *ErrorsTestSuite) TestCommonConstructors() {
tests := []struct {
name string
fn func() *Error
wantCode string
}{
{
name: "bad_request",
fn: func() *Error { return BadRequest("invalid input") },
wantCode: ErrCodeBadRequest,
},
{
name: "unauthorized",
fn: func() *Error { return Unauthorized("invalid token") },
wantCode: ErrCodeUnauthorized,
},
{
name: "forbidden",
fn: func() *Error { return Forbidden("access denied") },
wantCode: ErrCodeForbidden,
},
{
name: "not_found",
fn: func() *Error { return NotFound("resource missing") },
wantCode: ErrCodeNotFound,
},
{
name: "internal_server",
fn: func() *Error { return InternalServer("server error") },
wantCode: ErrCodeInternalServer,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
err := tt.fn()
s.Equal(tt.wantCode, err.Code)
})
}
}
func (s *ErrorsTestSuite) TestPackageNotFound() {
err := PackageNotFound("lodash", "4.17.21")
s.Equal(ErrCodePackageNotFound, err.Code)
s.Equal("Package lodash@4.17.21 not found", err.Message)
s.NotNil(err.Details)
details, ok := err.Details.(map[string]string)
s.True(ok)
s.Equal("lodash", details["package"])
s.Equal("4.17.21", details["version"])
}
func (s *ErrorsTestSuite) TestQuotaExceeded() {
limit := int64(1000000)
err := QuotaExceeded(limit)
s.Equal(ErrCodeQuotaExceeded, err.Code)
s.NotNil(err.Details)
details, ok := err.Details.(map[string]interface{})
s.True(ok)
s.Equal(limit, details["limit_bytes"])
}
func (s *ErrorsTestSuite) TestUnwrap() {
cause := errors.New("root cause")
wrapped := Wrap(cause, ErrCodeDatabaseFailure, "db error")
unwrapped := wrapped.Unwrap()
s.Equal(cause, unwrapped)
}
// Benchmark tests
func BenchmarkNewError(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = New(ErrCodeNotFound, "test error")
}
}
func BenchmarkNewErrorWithDetails(b *testing.B) {
details := map[string]string{"key": "value"}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = New(ErrCodeNotFound, "test error").WithDetails(details)
}
}
// Test edge cases
func (s *ErrorsTestSuite) TestEdgeCases() {
s.Run("nil_error_wrap", func() {
wrapped := Wrap(nil, ErrCodeInternalServer, "test")
s.Nil(wrapped.Cause)
})
s.Run("chained_wrapping", func() {
err1 := errors.New("base")
err2 := Wrap(err1, ErrCodeStorageFailure, "storage")
err3 := Wrap(err2, ErrCodeInternalServer, "internal")
s.True(errors.Is(err3, err2))
s.True(errors.Is(err3, err1))
})
s.Run("large_details", func() {
largeDetails := make(map[string]string)
for i := 0; i < 1000; i++ {
largeDetails[string(rune(i))] = "value"
}
err := New(ErrCodeBadRequest, "test").WithDetails(largeDetails)
s.Equal(largeDetails, err.Details)
})
}
// Table-driven test for error codes
func TestGetHTTPStatus(t *testing.T) {
tests := []struct {
code string
expectedStatus int
}{
{ErrCodeBadRequest, 400},
{ErrCodeUnauthorized, 401},
{ErrCodeForbidden, 403},
{ErrCodeNotFound, 404},
{ErrCodeConflict, 409},
{ErrCodePayloadTooLarge, 413},
{ErrCodeChecksumMismatch, 422},
{ErrCodeRateLimited, 429},
{ErrCodeInternalServer, 500},
{ErrCodeDatabaseFailure, 500},
{ErrCodeUpstreamFailure, 502},
{ErrCodeServiceUnavailable, 503},
{"UNKNOWN_CODE", 500}, // Default
}
for _, tt := range tests {
t.Run(tt.code, func(t *testing.T) {
assert.Equal(t, tt.expectedStatus, GetHTTPStatus(tt.code))
})
}
}
+90
View File
@@ -0,0 +1,90 @@
package errors
import (
"net/http"
"time"
json "github.com/goccy/go-json"
)
// Response is the standard API response envelope
type Response struct {
Success bool `json:"success"`
Data interface{} `json:"data,omitempty"`
Error *ErrorResponse `json:"error,omitempty"`
Metadata *ResponseMeta `json:"metadata,omitempty"`
}
// ErrorResponse contains error details
type ErrorResponse struct {
Code string `json:"code"`
Message string `json:"message"`
Details interface{} `json:"details,omitempty"`
Trace []string `json:"trace,omitempty"`
}
// ResponseMeta contains request metadata
type ResponseMeta struct {
RequestID string `json:"request_id"`
Timestamp string `json:"timestamp"`
Duration string `json:"duration,omitempty"`
Version string `json:"version"`
}
// WriteJSON writes a success response as JSON
func WriteJSON(w http.ResponseWriter, statusCode int, data interface{}, meta *ResponseMeta) {
response := Response{
Success: statusCode < 400,
Data: data,
Metadata: meta,
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
if err := json.NewEncoder(w).Encode(response); err != nil {
// Fallback to simple error response
http.Error(w, `{"success":false,"error":{"code":"ENCODING_ERROR","message":"Failed to encode response"}}`, http.StatusInternalServerError)
}
}
// WriteError writes an error response as JSON
func WriteError(w http.ResponseWriter, statusCode int, err *Error, meta *ResponseMeta) {
errResp := &ErrorResponse{
Code: err.Code,
Message: err.Message,
Details: err.Details,
Trace: err.Trace,
}
response := Response{
Success: false,
Error: errResp,
Metadata: meta,
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
if encErr := json.NewEncoder(w).Encode(response); encErr != nil {
// Fallback to simple error response
http.Error(w, `{"success":false,"error":{"code":"ENCODING_ERROR","message":"Failed to encode error response"}}`, http.StatusInternalServerError)
}
}
// WriteErrorSimple writes an error without metadata
func WriteErrorSimple(w http.ResponseWriter, err *Error) {
statusCode := GetHTTPStatus(err.Code)
meta := &ResponseMeta{
Timestamp: time.Now().UTC().Format(time.RFC3339),
}
WriteError(w, statusCode, err, meta)
}
// WriteJSONSimple writes a success response without metadata
func WriteJSONSimple(w http.ResponseWriter, statusCode int, data interface{}) {
meta := &ResponseMeta{
Timestamp: time.Now().UTC().Format(time.RFC3339),
}
WriteJSON(w, statusCode, data, meta)
}
+178
View File
@@ -0,0 +1,178 @@
package health
import (
"context"
"net/http"
"sync"
"time"
json "github.com/goccy/go-json"
"github.com/lukaszraczylo/gohoarder/internal/version"
)
// Status represents component health status
type Status string
const (
StatusHealthy Status = "healthy"
StatusUnhealthy Status = "unhealthy"
StatusDegraded Status = "degraded"
)
// Check represents a single health check
type Check struct {
Name string `json:"name"`
Status Status `json:"status"`
Error string `json:"error,omitempty"`
Fn func(context.Context) (Status, string) `json:"-"`
}
// Response is the health check response
type Response struct {
Success bool `json:"success"`
Data *HealthData `json:"data,omitempty"`
Metadata *Metadata `json:"metadata,omitempty"`
}
// HealthData contains health check data
type HealthData struct {
Status Status `json:"status"`
Version string `json:"version"`
Uptime string `json:"uptime"`
Components map[string]*Component `json:"components"`
}
// Component represents a system component
type Component struct {
Status Status `json:"status"`
Details map[string]interface{} `json:"details,omitempty"`
Error string `json:"error,omitempty"`
}
// Metadata contains response metadata
type Metadata struct {
RequestID string `json:"request_id"`
Timestamp string `json:"timestamp"`
}
// Checker manages health checks
type Checker struct {
checks []*Check
startTime time.Time
mu sync.RWMutex
}
// New creates a new health checker
func New() *Checker {
return &Checker{
checks: make([]*Check, 0),
startTime: time.Now(),
}
}
// AddCheck adds a health check
func (c *Checker) AddCheck(name string, fn func(context.Context) (Status, string)) {
c.mu.Lock()
defer c.mu.Unlock()
c.checks = append(c.checks, &Check{
Name: name,
Fn: fn,
})
}
// RunChecks runs all health checks
func (c *Checker) RunChecks(ctx context.Context) *HealthData {
c.mu.RLock()
checks := make([]*Check, len(c.checks))
copy(checks, c.checks)
c.mu.RUnlock()
components := make(map[string]*Component)
overallStatus := StatusHealthy
for _, check := range checks {
status, errMsg := check.Fn(ctx)
components[check.Name] = &Component{
Status: status,
Error: errMsg,
}
// Determine overall status
if status == StatusUnhealthy {
overallStatus = StatusUnhealthy
} else if status == StatusDegraded && overallStatus == StatusHealthy {
overallStatus = StatusDegraded
}
}
return &HealthData{
Status: overallStatus,
Version: version.Version,
Uptime: time.Since(c.startTime).String(),
Components: components,
}
}
// HealthHandler returns an HTTP handler for health checks
func (c *Checker) HealthHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
defer cancel()
healthData := c.RunChecks(ctx)
response := Response{
Success: healthData.Status == StatusHealthy,
Data: healthData,
Metadata: &Metadata{
RequestID: r.Header.Get("X-Request-ID"),
Timestamp: time.Now().UTC().Format(time.RFC3339),
},
}
statusCode := http.StatusOK
if healthData.Status == StatusUnhealthy {
statusCode = http.StatusServiceUnavailable
} else if healthData.Status == StatusDegraded {
statusCode = http.StatusOK // 200 but degraded
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
json.NewEncoder(w).Encode(response)
}
}
// ReadyHandler returns an HTTP handler for readiness checks
func (c *Checker) ReadyHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
defer cancel()
healthData := c.RunChecks(ctx)
ready := healthData.Status != StatusUnhealthy
response := Response{
Success: ready,
Data: &HealthData{
Status: healthData.Status,
Components: healthData.Components,
},
Metadata: &Metadata{
RequestID: r.Header.Get("X-Request-ID"),
Timestamp: time.Now().UTC().Format(time.RFC3339),
},
}
statusCode := http.StatusOK
if !ready {
statusCode = http.StatusServiceUnavailable
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
json.NewEncoder(w).Encode(response)
}
}
+275
View File
@@ -0,0 +1,275 @@
package lock
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"time"
"github.com/redis/go-redis/v9"
"github.com/rs/zerolog/log"
)
var (
ErrLockNotAcquired = errors.New("lock not acquired")
ErrLockNotHeld = errors.New("lock not held by this instance")
ErrInvalidTTL = errors.New("invalid TTL: must be positive")
)
// Lock represents a distributed lock
type Lock struct {
client *redis.Client
key string
value string
ttl time.Duration
}
// Manager manages distributed locks using Redis
type Manager struct {
client *redis.Client
}
// Config holds Redis connection configuration
type Config struct {
Addr string
Password string
DB int
}
// NewManager creates a new lock manager
func NewManager(cfg Config) (*Manager, error) {
client := redis.NewClient(&redis.Options{
Addr: cfg.Addr,
Password: cfg.Password,
DB: cfg.DB,
})
// Test connection
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
return nil, err
}
log.Info().
Str("addr", cfg.Addr).
Int("db", cfg.DB).
Msg("Connected to Redis for distributed locking")
return &Manager{
client: client,
}, nil
}
// Acquire attempts to acquire a lock with the given key and TTL
// Returns a Lock instance if successful, or an error if the lock is already held
func (m *Manager) Acquire(ctx context.Context, key string, ttl time.Duration) (*Lock, error) {
if ttl <= 0 {
return nil, ErrInvalidTTL
}
// Generate unique value for this lock instance
value, err := generateLockValue()
if err != nil {
return nil, err
}
// Try to acquire lock using SET NX (set if not exists)
success, err := m.client.SetNX(ctx, key, value, ttl).Result()
if err != nil {
log.Error().
Err(err).
Str("key", key).
Msg("Failed to acquire lock")
return nil, err
}
if !success {
log.Debug().
Str("key", key).
Msg("Lock already held by another instance")
return nil, ErrLockNotAcquired
}
log.Debug().
Str("key", key).
Dur("ttl", ttl).
Msg("Lock acquired successfully")
return &Lock{
client: m.client,
key: key,
value: value,
ttl: ttl,
}, nil
}
// TryAcquire attempts to acquire a lock, retrying for the specified duration
// Returns a Lock instance if successful within the timeout, or an error
func (m *Manager) TryAcquire(ctx context.Context, key string, ttl, timeout time.Duration) (*Lock, error) {
if ttl <= 0 {
return nil, ErrInvalidTTL
}
deadline := time.Now().Add(timeout)
ticker := time.NewTicker(50 * time.Millisecond)
defer ticker.Stop()
for {
lock, err := m.Acquire(ctx, key, ttl)
if err == nil {
return lock, nil
}
if err != ErrLockNotAcquired {
return nil, err
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-ticker.C:
if time.Now().After(deadline) {
return nil, ErrLockNotAcquired
}
}
}
}
// Release releases the lock
// Returns an error if the lock is not held by this instance
func (l *Lock) Release(ctx context.Context) error {
// Use Lua script to ensure atomic check-and-delete
// Only delete if the value matches (ensures we own the lock)
script := redis.NewScript(`
if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("del", KEYS[1])
else
return 0
end
`)
result, err := script.Run(ctx, l.client, []string{l.key}, l.value).Result()
if err != nil {
log.Error().
Err(err).
Str("key", l.key).
Msg("Failed to release lock")
return err
}
// Result of 0 means the lock was not deleted (not owned by us)
if result.(int64) == 0 {
log.Warn().
Str("key", l.key).
Msg("Attempted to release lock not held by this instance")
return ErrLockNotHeld
}
log.Debug().
Str("key", l.key).
Msg("Lock released successfully")
return nil
}
// Extend extends the lock TTL
// Returns an error if the lock is not held by this instance
func (l *Lock) Extend(ctx context.Context, additionalTTL time.Duration) error {
// Use Lua script to ensure atomic check-and-extend
script := redis.NewScript(`
if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("expire", KEYS[1], ARGV[2])
else
return 0
end
`)
newTTL := l.ttl + additionalTTL
result, err := script.Run(ctx, l.client, []string{l.key}, l.value, int(newTTL.Seconds())).Result()
if err != nil {
log.Error().
Err(err).
Str("key", l.key).
Msg("Failed to extend lock")
return err
}
if result.(int64) == 0 {
log.Warn().
Str("key", l.key).
Msg("Attempted to extend lock not held by this instance")
return ErrLockNotHeld
}
l.ttl = newTTL
log.Debug().
Str("key", l.key).
Dur("new_ttl", newTTL).
Msg("Lock TTL extended")
return nil
}
// IsHeld checks if the lock is still held by this instance
func (l *Lock) IsHeld(ctx context.Context) bool {
value, err := l.client.Get(ctx, l.key).Result()
if err != nil {
return false
}
return value == l.value
}
// Close closes the lock manager and its Redis connection
func (m *Manager) Close() error {
return m.client.Close()
}
// generateLockValue generates a cryptographically random lock value
func generateLockValue() (string, error) {
bytes := make([]byte, 16)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// WithLock executes a function while holding a distributed lock
// The lock is automatically released when the function returns
func (m *Manager) WithLock(ctx context.Context, key string, ttl time.Duration, fn func(context.Context) error) error {
lock, err := m.Acquire(ctx, key, ttl)
if err != nil {
return err
}
defer func() {
if err := lock.Release(context.Background()); err != nil {
log.Error().
Err(err).
Str("key", key).
Msg("Failed to release lock in defer")
}
}()
return fn(ctx)
}
// WithRetryLock executes a function while holding a distributed lock
// It retries acquisition for the specified timeout duration
func (m *Manager) WithRetryLock(ctx context.Context, key string, ttl, timeout time.Duration, fn func(context.Context) error) error {
lock, err := m.TryAcquire(ctx, key, ttl, timeout)
if err != nil {
return err
}
defer func() {
if err := lock.Release(context.Background()); err != nil {
log.Error().
Err(err).
Str("key", key).
Msg("Failed to release lock in defer")
}
}()
return fn(ctx)
}
+57
View File
@@ -0,0 +1,57 @@
package logger
import (
"os"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
// Config contains logger configuration
type Config struct {
Level string // debug, info, warn, error
Format string // json, pretty
}
// Init initializes the global logger
func Init(cfg Config) error {
// Set log level
level, err := zerolog.ParseLevel(cfg.Level)
if err != nil {
level = zerolog.InfoLevel
}
zerolog.SetGlobalLevel(level)
// Set time format
zerolog.TimeFieldFormat = zerolog.TimeFormatUnixMs
// Set format
if cfg.Format == "pretty" {
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: "15:04:05.000"})
} else {
// JSON format (default for production)
log.Logger = zerolog.New(os.Stdout).With().Timestamp().Logger()
}
return nil
}
// Get returns the global logger
func Get() *zerolog.Logger {
return &log.Logger
}
// WithFields returns a logger with additional fields
func WithFields(fields map[string]interface{}) *zerolog.Logger {
logger := log.Logger
for k, v := range fields {
logger = logger.With().Interface(k, v).Logger()
}
return &logger
}
// WithRequestID returns a logger with request ID
func WithRequestID(requestID string) *zerolog.Logger {
logger := log.With().Str("request_id", requestID).Logger()
return &logger
}
+65
View File
@@ -0,0 +1,65 @@
package logger
import (
"net/http"
"time"
"github.com/lukaszraczylo/gohoarder/pkg/uuid"
"github.com/rs/zerolog/log"
)
// responseWriter wraps http.ResponseWriter to capture status code
type responseWriter struct {
http.ResponseWriter
statusCode int
written int64
}
func (rw *responseWriter) WriteHeader(code int) {
rw.statusCode = code
rw.ResponseWriter.WriteHeader(code)
}
func (rw *responseWriter) Write(b []byte) (int, error) {
n, err := rw.ResponseWriter.Write(b)
rw.written += int64(n)
return n, err
}
// Middleware is HTTP middleware for request logging
func Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// Generate request ID
requestID := r.Header.Get("X-Request-ID")
if requestID == "" {
requestID = "req_" + uuid.New().String()[:8]
}
// Wrap response writer
rw := &responseWriter{
ResponseWriter: w,
statusCode: http.StatusOK,
}
// Set request ID in response header
rw.Header().Set("X-Request-ID", requestID)
// Call next handler
next.ServeHTTP(rw, r)
// Log request
duration := time.Since(start)
log.Info().
Str("request_id", requestID).
Str("method", r.Method).
Str("path", r.URL.Path).
Str("remote_addr", r.RemoteAddr).
Str("user_agent", r.UserAgent()).
Int("status", rw.statusCode).
Int64("bytes", rw.written).
Dur("duration_ms", duration).
Msg("HTTP request")
})
}
+528
View File
@@ -0,0 +1,528 @@
package file
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
"time"
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
"github.com/rs/zerolog/log"
)
// Store implements a file-based metadata store
type Store struct {
basePath string
mu sync.RWMutex
}
// Config holds file store configuration
type Config struct {
Path string
}
// New creates a new file-based metadata store
func New(cfg Config) (*Store, error) {
if cfg.Path == "" {
cfg.Path = "./metadata"
}
// Create directory if it doesn't exist
if err := os.MkdirAll(cfg.Path, 0755); err != nil {
return nil, fmt.Errorf("failed to create metadata directory: %w", err)
}
log.Info().
Str("path", cfg.Path).
Msg("File-based metadata store initialized")
return &Store{
basePath: cfg.Path,
}, nil
}
// SavePackage saves package metadata
func (s *Store) SavePackage(ctx context.Context, pkg *metadata.Package) error {
s.mu.Lock()
defer s.mu.Unlock()
// Create registry directory
regDir := filepath.Join(s.basePath, pkg.Registry)
if err := os.MkdirAll(regDir, 0755); err != nil {
return err
}
// Save to file
filename := filepath.Join(regDir, fmt.Sprintf("%s-%s.json", pkg.Name, pkg.Version))
data, err := json.MarshalIndent(pkg, "", " ")
if err != nil {
return err
}
return os.WriteFile(filename, data, 0644)
}
// GetPackage retrieves package metadata
func (s *Store) GetPackage(ctx context.Context, registry, name, version string) (*metadata.Package, error) {
s.mu.RLock()
defer s.mu.RUnlock()
filename := filepath.Join(s.basePath, registry, fmt.Sprintf("%s-%s.json", name, version))
data, err := os.ReadFile(filename)
if err != nil {
if os.IsNotExist(err) {
return nil, nil
}
return nil, err
}
var pkg metadata.Package
if err := json.Unmarshal(data, &pkg); err != nil {
return nil, err
}
return &pkg, nil
}
// ListPackages lists all packages
func (s *Store) ListPackages(ctx context.Context, opts *metadata.ListOptions) ([]*metadata.Package, error) {
s.mu.RLock()
defer s.mu.RUnlock()
var packages []*metadata.Package
// Walk through all files
err := filepath.Walk(s.basePath, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if info.IsDir() || filepath.Ext(path) != ".json" {
return nil
}
data, err := os.ReadFile(path)
if err != nil {
return nil // Skip files we can't read
}
var pkg metadata.Package
if err := json.Unmarshal(data, &pkg); err != nil {
return nil // Skip invalid JSON
}
packages = append(packages, &pkg)
return nil
})
if err != nil {
return nil, err
}
// Apply pagination if options provided
if opts != nil {
if opts.Offset >= len(packages) {
return []*metadata.Package{}, nil
}
end := opts.Offset + opts.Limit
if end > len(packages) {
end = len(packages)
}
return packages[opts.Offset:end], nil
}
return packages, nil
}
// DeletePackage deletes package metadata
func (s *Store) DeletePackage(ctx context.Context, registry, name, version string) error {
s.mu.Lock()
defer s.mu.Unlock()
filename := filepath.Join(s.basePath, registry, fmt.Sprintf("%s-%s.json", name, version))
if err := os.Remove(filename); err != nil && !os.IsNotExist(err) {
return err
}
return nil
}
// SaveScanResult saves scan result
func (s *Store) SaveScanResult(ctx context.Context, result *metadata.ScanResult) error {
s.mu.Lock()
defer s.mu.Unlock()
// Create scans directory
scanDir := filepath.Join(s.basePath, "scans", result.Registry, result.PackageName)
if err := os.MkdirAll(scanDir, 0755); err != nil {
return err
}
// Save to file with timestamp
timestamp := time.Now().Unix()
filename := filepath.Join(scanDir, fmt.Sprintf("%s-%d.json", result.PackageVersion, timestamp))
data, err := json.MarshalIndent(result, "", " ")
if err != nil {
return err
}
return os.WriteFile(filename, data, 0644)
}
// UpdateDownloadCount increments download counter
func (s *Store) UpdateDownloadCount(ctx context.Context, registry, name, version string) error {
s.mu.Lock()
defer s.mu.Unlock()
// Load package
pkg, err := s.GetPackage(ctx, registry, name, version)
if err != nil || pkg == nil {
return err
}
// Increment counter
pkg.DownloadCount++
pkg.LastAccessed = time.Now()
// Save back
return s.SavePackage(ctx, pkg)
}
// GetStats returns statistics for a registry
func (s *Store) GetStats(ctx context.Context, registry string) (*metadata.Stats, error) {
s.mu.RLock()
defer s.mu.RUnlock()
stats := &metadata.Stats{
Registry: registry,
LastUpdated: time.Now(),
}
// Walk through files and calculate stats
err := filepath.Walk(s.basePath, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if info.IsDir() || filepath.Ext(path) != ".json" {
return nil
}
data, err := os.ReadFile(path)
if err != nil {
return nil
}
var pkg metadata.Package
if err := json.Unmarshal(data, &pkg); err != nil {
return nil
}
// Filter by registry if specified
if registry != "" && pkg.Registry != registry {
return nil
}
stats.TotalPackages++
stats.TotalSize += pkg.Size
stats.TotalDownloads += pkg.DownloadCount
if pkg.SecurityScanned {
stats.ScannedPackages++
}
return nil
})
if err != nil {
return nil, err
}
return stats, nil
}
// GetScanResult retrieves latest scan result
func (s *Store) GetScanResult(ctx context.Context, registry, name, version string) (*metadata.ScanResult, error) {
s.mu.RLock()
defer s.mu.RUnlock()
scanDir := filepath.Join(s.basePath, "scans", registry, name)
pattern := filepath.Join(scanDir, fmt.Sprintf("%s-*.json", version))
matches, err := filepath.Glob(pattern)
if err != nil {
return nil, err
}
if len(matches) == 0 {
return nil, nil
}
// Get the latest file
latestFile := matches[len(matches)-1]
data, err := os.ReadFile(latestFile)
if err != nil {
return nil, err
}
var result metadata.ScanResult
if err := json.Unmarshal(data, &result); err != nil {
return nil, err
}
return &result, nil
}
// Count returns total number of packages
func (s *Store) Count(ctx context.Context) (int, error) {
s.mu.RLock()
defer s.mu.RUnlock()
count := 0
err := filepath.Walk(s.basePath, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() && filepath.Ext(path) == ".json" && filepath.Dir(path) != filepath.Join(s.basePath, "scans") {
count++
}
return nil
})
if err != nil {
return 0, err
}
return count, nil
}
// Health checks if the store is healthy
func (s *Store) Health(ctx context.Context) error {
// Check if directory is accessible
_, err := os.Stat(s.basePath)
return err
}
// SaveCVEBypass saves a CVE bypass (admin only)
func (s *Store) SaveCVEBypass(ctx context.Context, bypass *metadata.CVEBypass) error {
s.mu.Lock()
defer s.mu.Unlock()
// Create bypasses directory
bypassesDir := filepath.Join(s.basePath, "bypasses")
if err := os.MkdirAll(bypassesDir, 0755); err != nil {
return err
}
// Save to file
filename := filepath.Join(bypassesDir, fmt.Sprintf("%s.json", bypass.ID))
data, err := json.MarshalIndent(bypass, "", " ")
if err != nil {
return err
}
return os.WriteFile(filename, data, 0644)
}
// GetActiveCVEBypasses retrieves all active (non-expired) CVE bypasses
func (s *Store) GetActiveCVEBypasses(ctx context.Context) ([]*metadata.CVEBypass, error) {
s.mu.RLock()
defer s.mu.RUnlock()
bypassesDir := filepath.Join(s.basePath, "bypasses")
var bypasses []*metadata.CVEBypass
now := time.Now()
// Read all bypass files
err := filepath.Walk(bypassesDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
if os.IsNotExist(err) {
return nil // bypasses directory doesn't exist yet
}
return err
}
if info.IsDir() || filepath.Ext(path) != ".json" {
return nil
}
data, err := os.ReadFile(path)
if err != nil {
return err
}
var bypass metadata.CVEBypass
if err := json.Unmarshal(data, &bypass); err != nil {
log.Warn().Err(err).Str("file", path).Msg("Failed to unmarshal bypass")
return nil
}
// Only include active and non-expired bypasses
if bypass.Active && bypass.ExpiresAt.After(now) {
bypasses = append(bypasses, &bypass)
}
return nil
})
if err != nil {
return nil, err
}
return bypasses, nil
}
// ListCVEBypasses lists all CVE bypasses (including expired)
func (s *Store) ListCVEBypasses(ctx context.Context, opts *metadata.BypassListOptions) ([]*metadata.CVEBypass, error) {
s.mu.RLock()
defer s.mu.RUnlock()
bypassesDir := filepath.Join(s.basePath, "bypasses")
var bypasses []*metadata.CVEBypass
now := time.Now()
// Read all bypass files
err := filepath.Walk(bypassesDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
if os.IsNotExist(err) {
return nil // bypasses directory doesn't exist yet
}
return err
}
if info.IsDir() || filepath.Ext(path) != ".json" {
return nil
}
data, err := os.ReadFile(path)
if err != nil {
return err
}
var bypass metadata.CVEBypass
if err := json.Unmarshal(data, &bypass); err != nil {
log.Warn().Err(err).Str("file", path).Msg("Failed to unmarshal bypass")
return nil
}
// Apply filters if options provided
if opts != nil {
if opts.Type != "" && bypass.Type != opts.Type {
return nil
}
if !opts.IncludeExpired && bypass.ExpiresAt.Before(now) {
return nil
}
if opts.ActiveOnly && !bypass.Active {
return nil
}
}
bypasses = append(bypasses, &bypass)
return nil
})
if err != nil {
return nil, err
}
// Apply limit and offset if specified
if opts != nil {
if opts.Offset > 0 && opts.Offset < len(bypasses) {
bypasses = bypasses[opts.Offset:]
} else if opts.Offset >= len(bypasses) {
return []*metadata.CVEBypass{}, nil
}
if opts.Limit > 0 && opts.Limit < len(bypasses) {
bypasses = bypasses[:opts.Limit]
}
}
return bypasses, nil
}
// DeleteCVEBypass deletes a CVE bypass by ID
func (s *Store) DeleteCVEBypass(ctx context.Context, id string) error {
s.mu.Lock()
defer s.mu.Unlock()
filename := filepath.Join(s.basePath, "bypasses", fmt.Sprintf("%s.json", id))
err := os.Remove(filename)
if err != nil {
if os.IsNotExist(err) {
return fmt.Errorf("CVE bypass not found: %s", id)
}
return err
}
return nil
}
// CleanupExpiredBypasses removes expired bypasses
func (s *Store) CleanupExpiredBypasses(ctx context.Context) (int, error) {
s.mu.Lock()
defer s.mu.Unlock()
bypassesDir := filepath.Join(s.basePath, "bypasses")
count := 0
now := time.Now()
// Read all bypass files
err := filepath.Walk(bypassesDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
if os.IsNotExist(err) {
return nil // bypasses directory doesn't exist yet
}
return err
}
if info.IsDir() || filepath.Ext(path) != ".json" {
return nil
}
data, err := os.ReadFile(path)
if err != nil {
return err
}
var bypass metadata.CVEBypass
if err := json.Unmarshal(data, &bypass); err != nil {
log.Warn().Err(err).Str("file", path).Msg("Failed to unmarshal bypass")
return nil
}
// Delete if expired
if bypass.ExpiresAt.Before(now) {
if err := os.Remove(path); err != nil {
log.Warn().Err(err).Str("file", path).Msg("Failed to delete expired bypass")
} else {
count++
}
}
return nil
})
if err != nil {
return 0, err
}
return count, nil
}
// Close closes the store
func (s *Store) Close() error {
// Nothing to close for file-based store
return nil
}
+170
View File
@@ -0,0 +1,170 @@
package metadata
import (
"context"
"time"
)
// Store is an alias for MetadataStore for convenience
type Store = MetadataStore
// MetadataStore defines the interface for package metadata storage
type MetadataStore interface {
// SavePackage saves package metadata
SavePackage(ctx context.Context, pkg *Package) error
// GetPackage retrieves package metadata
GetPackage(ctx context.Context, registry, name, version string) (*Package, error)
// DeletePackage deletes package metadata
DeletePackage(ctx context.Context, registry, name, version string) error
// ListPackages lists packages with optional filtering
ListPackages(ctx context.Context, opts *ListOptions) ([]*Package, error)
// UpdateDownloadCount increments download counter
UpdateDownloadCount(ctx context.Context, registry, name, version string) error
// GetStats returns statistics
GetStats(ctx context.Context, registry string) (*Stats, error)
// SaveScanResult saves security scan result
SaveScanResult(ctx context.Context, result *ScanResult) error
// GetScanResult retrieves security scan result
GetScanResult(ctx context.Context, registry, name, version string) (*ScanResult, error)
// SaveCVEBypass saves a CVE bypass (admin only)
SaveCVEBypass(ctx context.Context, bypass *CVEBypass) error
// GetActiveCVEBypasses retrieves all active (non-expired) CVE bypasses
GetActiveCVEBypasses(ctx context.Context) ([]*CVEBypass, error)
// ListCVEBypasses lists all CVE bypasses (including expired)
ListCVEBypasses(ctx context.Context, opts *BypassListOptions) ([]*CVEBypass, error)
// DeleteCVEBypass deletes a CVE bypass by ID
DeleteCVEBypass(ctx context.Context, id string) error
// CleanupExpiredBypasses removes expired bypasses
CleanupExpiredBypasses(ctx context.Context) (int, error)
// Count returns total number of packages
Count(ctx context.Context) (int, error)
// Health checks metadata store health
Health(ctx context.Context) error
// Close closes the metadata store
Close() error
}
// Package represents package metadata
type Package struct {
ID string `json:"id"`
Registry string `json:"registry"` // npm, pypi, go
Name string `json:"name"` // Package name
Version string `json:"version"` // Package version
StorageKey string `json:"storage_key"` // Key in storage backend
Size int64 `json:"size"` // Package size in bytes
ChecksumMD5 string `json:"checksum_md5"` // MD5 checksum
ChecksumSHA256 string `json:"checksum_sha256"` // SHA256 checksum
UpstreamURL string `json:"upstream_url"` // Original upstream URL
CachedAt time.Time `json:"cached_at"` // When cached
LastAccessed time.Time `json:"last_accessed"` // Last access time
ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = never)
DownloadCount int64 `json:"download_count"` // Download counter
Metadata map[string]string `json:"metadata"` // Additional metadata
SecurityScanned bool `json:"security_scanned"` // Has been scanned
}
// ScanResult represents a security scan result
type ScanResult struct {
ID string `json:"id"`
Registry string `json:"registry"`
PackageName string `json:"package_name"`
PackageVersion string `json:"package_version"`
Scanner string `json:"scanner"` // trivy, osv, etc.
ScannedAt time.Time `json:"scanned_at"`
Status ScanStatus `json:"status"` // clean, vulnerable, error
VulnerabilityCount int `json:"vulnerability_count"`
Vulnerabilities []Vulnerability `json:"vulnerabilities"`
Details map[string]interface{} `json:"details"` // Scanner-specific details
}
// Vulnerability represents a security vulnerability
type Vulnerability struct {
ID string `json:"id"` // CVE-xxx, GHSA-xxx, etc.
Severity string `json:"severity"` // critical, high, medium, low
Title string `json:"title"`
Description string `json:"description"`
References []string `json:"references"`
FixedIn string `json:"fixed_in"` // Version where fixed
DetectedBy []string `json:"detected_by,omitempty"` // List of scanners that detected this vulnerability
}
// ScanStatus represents scan result status
type ScanStatus string
const (
ScanStatusClean ScanStatus = "clean"
ScanStatusVulnerable ScanStatus = "vulnerable"
ScanStatusError ScanStatus = "error"
ScanStatusPending ScanStatus = "pending"
)
// Stats represents metadata statistics
type Stats struct {
Registry string `json:"registry"`
TotalPackages int64 `json:"total_packages"`
TotalSize int64 `json:"total_size"`
TotalDownloads int64 `json:"total_downloads"`
ScannedPackages int64 `json:"scanned_packages"`
VulnerablePackages int64 `json:"vulnerable_packages"`
LastUpdated time.Time `json:"last_updated"`
}
// CVEBypass represents a temporary bypass for a CVE or package
type CVEBypass struct {
ID string `json:"id"` // Unique bypass ID
Type BypassType `json:"type"` // cve, package
Target string `json:"target"` // CVE ID (e.g., "CVE-2021-23337") or package (e.g., "npm/lodash@4.17.20")
Reason string `json:"reason"` // Why this bypass was created
CreatedBy string `json:"created_by"` // Admin user who created it
CreatedAt time.Time `json:"created_at"` // When created
ExpiresAt time.Time `json:"expires_at"` // When it expires
AppliesTo string `json:"applies_to,omitempty"` // Optional: limit to specific package (for CVE bypasses)
NotifyOnExpiry bool `json:"notify_on_expiry"` // Send notification when expired
Active bool `json:"active"` // Can be deactivated without deletion
}
// BypassType represents the type of bypass
type BypassType string
const (
BypassTypeCVE BypassType = "cve" // Bypass specific CVE
BypassTypePackage BypassType = "package" // Bypass entire package
)
// BypassListOptions contains options for listing CVE bypasses
type BypassListOptions struct {
Type BypassType // Filter by type
IncludeExpired bool // Include expired bypasses
ActiveOnly bool // Only active bypasses
Limit int // Max results
Offset int // Pagination offset
}
// ListOptions contains options for listing packages
type ListOptions struct {
Registry string // Filter by registry
NamePrefix string // Filter by name prefix
MinSize int64 // Minimum package size
MaxSize int64 // Maximum package size
ScannedOnly bool // Only scanned packages
SinceDate time.Time // Packages cached since date
Limit int // Max results
Offset int // Pagination offset
SortBy string // Sort field (name, size, cached_at, download_count)
SortDesc bool // Sort descending
}
+707
View File
@@ -0,0 +1,707 @@
package sqlite
import (
"context"
"database/sql"
"fmt"
"sync"
"time"
goccy_json "github.com/goccy/go-json"
_ "modernc.org/sqlite"
"github.com/lukaszraczylo/gohoarder/pkg/errors"
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
"github.com/rs/zerolog/log"
)
// SQLiteStore implements metadata.MetadataStore using SQLite
type SQLiteStore struct {
db *sql.DB
mu sync.RWMutex
}
// Config holds SQLite configuration
type Config struct {
Path string // Database file path
MaxOpenConns int // Maximum open connections
MaxIdleConns int // Maximum idle connections
}
const schema = `
CREATE TABLE IF NOT EXISTS packages (
id TEXT PRIMARY KEY,
registry TEXT NOT NULL,
name TEXT NOT NULL,
version TEXT NOT NULL,
storage_key TEXT NOT NULL,
size INTEGER NOT NULL,
checksum_md5 TEXT,
checksum_sha256 TEXT,
upstream_url TEXT,
cached_at DATETIME NOT NULL,
last_accessed DATETIME NOT NULL,
expires_at DATETIME,
download_count INTEGER DEFAULT 0,
metadata TEXT,
security_scanned BOOLEAN DEFAULT 0,
UNIQUE(registry, name, version)
);
CREATE INDEX IF NOT EXISTS idx_packages_registry ON packages(registry);
CREATE INDEX IF NOT EXISTS idx_packages_name ON packages(name);
CREATE INDEX IF NOT EXISTS idx_packages_cached_at ON packages(cached_at);
CREATE INDEX IF NOT EXISTS idx_packages_last_accessed ON packages(last_accessed);
CREATE INDEX IF NOT EXISTS idx_packages_expires_at ON packages(expires_at);
CREATE TABLE IF NOT EXISTS scan_results (
id TEXT PRIMARY KEY,
registry TEXT NOT NULL,
package_name TEXT NOT NULL,
package_version TEXT NOT NULL,
scanner TEXT NOT NULL,
scanned_at DATETIME NOT NULL,
status TEXT NOT NULL,
vulnerability_count INTEGER DEFAULT 0,
vulnerabilities TEXT,
details TEXT,
UNIQUE(registry, package_name, package_version, scanner)
);
CREATE INDEX IF NOT EXISTS idx_scan_results_registry ON scan_results(registry);
CREATE INDEX IF NOT EXISTS idx_scan_results_package ON scan_results(package_name);
CREATE INDEX IF NOT EXISTS idx_scan_results_status ON scan_results(status);
CREATE TABLE IF NOT EXISTS cve_bypasses (
id TEXT PRIMARY KEY,
type TEXT NOT NULL,
target TEXT NOT NULL,
reason TEXT NOT NULL,
created_by TEXT NOT NULL,
created_at DATETIME NOT NULL,
expires_at DATETIME NOT NULL,
applies_to TEXT,
notify_on_expiry BOOLEAN DEFAULT 0,
active BOOLEAN DEFAULT 1
);
CREATE INDEX IF NOT EXISTS idx_cve_bypasses_type ON cve_bypasses(type);
CREATE INDEX IF NOT EXISTS idx_cve_bypasses_target ON cve_bypasses(target);
CREATE INDEX IF NOT EXISTS idx_cve_bypasses_expires_at ON cve_bypasses(expires_at);
CREATE INDEX IF NOT EXISTS idx_cve_bypasses_active ON cve_bypasses(active);
`
// New creates a new SQLite metadata store
func New(cfg Config) (*SQLiteStore, error) {
if cfg.Path == "" {
return nil, errors.New(errors.ErrCodeInvalidConfig, "SQLite database path is required")
}
if cfg.MaxOpenConns == 0 {
cfg.MaxOpenConns = 10
}
if cfg.MaxIdleConns == 0 {
cfg.MaxIdleConns = 5
}
// Open database with WAL mode for better concurrency
dsn := fmt.Sprintf("%s?_journal_mode=WAL&_busy_timeout=5000&_synchronous=NORMAL&_cache_size=2000", cfg.Path)
db, err := sql.Open("sqlite", dsn)
if err != nil {
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to open SQLite database")
}
db.SetMaxOpenConns(cfg.MaxOpenConns)
db.SetMaxIdleConns(cfg.MaxIdleConns)
db.SetConnMaxLifetime(time.Hour)
// Create schema
if _, err := db.Exec(schema); err != nil {
db.Close()
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create SQLite schema")
}
return &SQLiteStore{
db: db,
}, nil
}
// SavePackage saves package metadata
func (s *SQLiteStore) SavePackage(ctx context.Context, pkg *metadata.Package) error {
s.mu.Lock()
defer s.mu.Unlock()
// Serialize metadata
metadataJSON, err := goccy_json.Marshal(pkg.Metadata)
if err != nil {
return errors.Wrap(err, errors.ErrCodeInternalServer, "failed to serialize package metadata")
}
var expiresAt interface{}
if pkg.ExpiresAt != nil {
expiresAt = pkg.ExpiresAt
}
query := `
INSERT INTO packages (
id, registry, name, version, storage_key, size,
checksum_md5, checksum_sha256, upstream_url,
cached_at, last_accessed, expires_at, download_count,
metadata, security_scanned
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(registry, name, version) DO UPDATE SET
storage_key = excluded.storage_key,
size = excluded.size,
checksum_md5 = excluded.checksum_md5,
checksum_sha256 = excluded.checksum_sha256,
upstream_url = excluded.upstream_url,
last_accessed = excluded.last_accessed,
expires_at = excluded.expires_at,
metadata = excluded.metadata,
security_scanned = excluded.security_scanned
`
_, err = s.db.ExecContext(ctx, query,
pkg.ID, pkg.Registry, pkg.Name, pkg.Version, pkg.StorageKey, pkg.Size,
pkg.ChecksumMD5, pkg.ChecksumSHA256, pkg.UpstreamURL,
pkg.CachedAt, pkg.LastAccessed, expiresAt, pkg.DownloadCount,
string(metadataJSON), pkg.SecurityScanned,
)
if err != nil {
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to save package metadata")
}
return nil
}
// GetPackage retrieves package metadata
func (s *SQLiteStore) GetPackage(ctx context.Context, registry, name, version string) (*metadata.Package, error) {
s.mu.RLock()
defer s.mu.RUnlock()
query := `
SELECT id, registry, name, version, storage_key, size,
checksum_md5, checksum_sha256, upstream_url,
cached_at, last_accessed, expires_at, download_count,
metadata, security_scanned
FROM packages
WHERE registry = ? AND name = ? AND version = ?
`
var pkg metadata.Package
var metadataJSON string
var expiresAt sql.NullTime
err := s.db.QueryRowContext(ctx, query, registry, name, version).Scan(
&pkg.ID, &pkg.Registry, &pkg.Name, &pkg.Version, &pkg.StorageKey, &pkg.Size,
&pkg.ChecksumMD5, &pkg.ChecksumSHA256, &pkg.UpstreamURL,
&pkg.CachedAt, &pkg.LastAccessed, &expiresAt, &pkg.DownloadCount,
&metadataJSON, &pkg.SecurityScanned,
)
if err == sql.ErrNoRows {
return nil, errors.NotFound(fmt.Sprintf("package not found: %s/%s@%s", registry, name, version))
}
if err != nil {
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get package metadata")
}
if expiresAt.Valid {
pkg.ExpiresAt = &expiresAt.Time
}
// Deserialize metadata
if metadataJSON != "" {
if err := goccy_json.Unmarshal([]byte(metadataJSON), &pkg.Metadata); err != nil {
log.Warn().Err(err).Msg("Failed to deserialize package metadata")
}
}
return &pkg, nil
}
// DeletePackage deletes package metadata
func (s *SQLiteStore) DeletePackage(ctx context.Context, registry, name, version string) error {
s.mu.Lock()
defer s.mu.Unlock()
query := "DELETE FROM packages WHERE registry = ? AND name = ? AND version = ?"
result, err := s.db.ExecContext(ctx, query, registry, name, version)
if err != nil {
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to delete package metadata")
}
rows, _ := result.RowsAffected()
if rows == 0 {
return errors.NotFound(fmt.Sprintf("package not found: %s/%s@%s", registry, name, version))
}
return nil
}
// ListPackages lists packages with optional filtering
func (s *SQLiteStore) ListPackages(ctx context.Context, opts *metadata.ListOptions) ([]*metadata.Package, error) {
s.mu.RLock()
defer s.mu.RUnlock()
query := "SELECT id, registry, name, version, storage_key, size, checksum_md5, checksum_sha256, upstream_url, cached_at, last_accessed, expires_at, download_count, metadata, security_scanned FROM packages WHERE 1=1"
args := []interface{}{}
if opts != nil {
if opts.Registry != "" {
query += " AND registry = ?"
args = append(args, opts.Registry)
}
if opts.NamePrefix != "" {
query += " AND name LIKE ?"
args = append(args, opts.NamePrefix+"%")
}
if opts.MinSize > 0 {
query += " AND size >= ?"
args = append(args, opts.MinSize)
}
if opts.MaxSize > 0 {
query += " AND size <= ?"
args = append(args, opts.MaxSize)
}
if opts.ScannedOnly {
query += " AND security_scanned = 1"
}
if !opts.SinceDate.IsZero() {
query += " AND cached_at >= ?"
args = append(args, opts.SinceDate)
}
// Sorting
sortBy := "cached_at"
if opts.SortBy != "" {
sortBy = opts.SortBy
}
sortOrder := "ASC"
if opts.SortDesc {
sortOrder = "DESC"
}
query += fmt.Sprintf(" ORDER BY %s %s", sortBy, sortOrder)
// Pagination
if opts.Limit > 0 {
query += " LIMIT ?"
args = append(args, opts.Limit)
}
if opts.Offset > 0 {
query += " OFFSET ?"
args = append(args, opts.Offset)
}
}
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to list packages")
}
defer rows.Close()
var packages []*metadata.Package
for rows.Next() {
var pkg metadata.Package
var metadataJSON string
var expiresAt sql.NullTime
err := rows.Scan(
&pkg.ID, &pkg.Registry, &pkg.Name, &pkg.Version, &pkg.StorageKey, &pkg.Size,
&pkg.ChecksumMD5, &pkg.ChecksumSHA256, &pkg.UpstreamURL,
&pkg.CachedAt, &pkg.LastAccessed, &expiresAt, &pkg.DownloadCount,
&metadataJSON, &pkg.SecurityScanned,
)
if err != nil {
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to scan package row")
}
if expiresAt.Valid {
pkg.ExpiresAt = &expiresAt.Time
}
if metadataJSON != "" {
goccy_json.Unmarshal([]byte(metadataJSON), &pkg.Metadata)
}
packages = append(packages, &pkg)
}
return packages, nil
}
// UpdateDownloadCount increments download counter
func (s *SQLiteStore) UpdateDownloadCount(ctx context.Context, registry, name, version string) error {
s.mu.Lock()
defer s.mu.Unlock()
query := `
UPDATE packages
SET download_count = download_count + 1,
last_accessed = ?
WHERE registry = ? AND name = ? AND version = ?
`
_, err := s.db.ExecContext(ctx, query, time.Now(), registry, name, version)
if err != nil {
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to update download count")
}
return nil
}
// GetStats returns statistics
func (s *SQLiteStore) GetStats(ctx context.Context, registry string) (*metadata.Stats, error) {
s.mu.RLock()
defer s.mu.RUnlock()
query := `
SELECT
COUNT(*) as total_packages,
COALESCE(SUM(size), 0) as total_size,
COALESCE(SUM(download_count), 0) as total_downloads,
COALESCE(SUM(CASE WHEN security_scanned = 1 THEN 1 ELSE 0 END), 0) as scanned_packages
FROM packages
`
args := []interface{}{}
if registry != "" {
query += " WHERE registry = ?"
args = append(args, registry)
}
var stats metadata.Stats
stats.Registry = registry
stats.LastUpdated = time.Now()
err := s.db.QueryRowContext(ctx, query, args...).Scan(
&stats.TotalPackages,
&stats.TotalSize,
&stats.TotalDownloads,
&stats.ScannedPackages,
)
if err != nil {
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get stats")
}
// Count vulnerable packages
vulnQuery := `SELECT COUNT(*) FROM scan_results WHERE status = 'vulnerable'`
vulnArgs := []interface{}{}
if registry != "" {
vulnQuery += " AND registry = ?"
vulnArgs = append(vulnArgs, registry)
}
s.db.QueryRowContext(ctx, vulnQuery, vulnArgs...).Scan(&stats.VulnerablePackages)
return &stats, nil
}
// SaveScanResult saves security scan result
func (s *SQLiteStore) SaveScanResult(ctx context.Context, result *metadata.ScanResult) error {
s.mu.Lock()
defer s.mu.Unlock()
// Serialize vulnerabilities and details
vulnJSON, err := goccy_json.Marshal(result.Vulnerabilities)
if err != nil {
return errors.Wrap(err, errors.ErrCodeInternalServer, "failed to serialize vulnerabilities")
}
detailsJSON, err := goccy_json.Marshal(result.Details)
if err != nil {
return errors.Wrap(err, errors.ErrCodeInternalServer, "failed to serialize scan details")
}
query := `
INSERT INTO scan_results (
id, registry, package_name, package_version, scanner,
scanned_at, status, vulnerability_count, vulnerabilities, details
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(registry, package_name, package_version, scanner) DO UPDATE SET
scanned_at = excluded.scanned_at,
status = excluded.status,
vulnerability_count = excluded.vulnerability_count,
vulnerabilities = excluded.vulnerabilities,
details = excluded.details
`
_, err = s.db.ExecContext(ctx, query,
result.ID, result.Registry, result.PackageName, result.PackageVersion, result.Scanner,
result.ScannedAt, result.Status, result.VulnerabilityCount,
string(vulnJSON), string(detailsJSON),
)
if err != nil {
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to save scan result")
}
// Update package security_scanned flag
updateQuery := `UPDATE packages SET security_scanned = 1 WHERE registry = ? AND name = ? AND version = ?`
s.db.ExecContext(ctx, updateQuery, result.Registry, result.PackageName, result.PackageVersion)
return nil
}
// GetScanResult retrieves security scan result
func (s *SQLiteStore) GetScanResult(ctx context.Context, registry, name, version string) (*metadata.ScanResult, error) {
s.mu.RLock()
defer s.mu.RUnlock()
query := `
SELECT id, registry, package_name, package_version, scanner,
scanned_at, status, vulnerability_count, vulnerabilities, details
FROM scan_results
WHERE registry = ? AND package_name = ? AND package_version = ?
ORDER BY scanned_at DESC
LIMIT 1
`
var result metadata.ScanResult
var vulnJSON, detailsJSON string
err := s.db.QueryRowContext(ctx, query, registry, name, version).Scan(
&result.ID, &result.Registry, &result.PackageName, &result.PackageVersion, &result.Scanner,
&result.ScannedAt, &result.Status, &result.VulnerabilityCount,
&vulnJSON, &detailsJSON,
)
if err == sql.ErrNoRows {
return nil, errors.NotFound(fmt.Sprintf("scan result not found: %s/%s@%s", registry, name, version))
}
if err != nil {
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get scan result")
}
// Deserialize
if vulnJSON != "" {
goccy_json.Unmarshal([]byte(vulnJSON), &result.Vulnerabilities)
}
if detailsJSON != "" {
goccy_json.Unmarshal([]byte(detailsJSON), &result.Details)
}
return &result, nil
}
// Count returns total number of packages
func (s *SQLiteStore) Count(ctx context.Context) (int, error) {
s.mu.RLock()
defer s.mu.RUnlock()
var count int
query := "SELECT COUNT(*) FROM packages"
err := s.db.QueryRowContext(ctx, query).Scan(&count)
if err != nil {
return 0, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to count packages")
}
return count, nil
}
// Health checks metadata store health
func (s *SQLiteStore) Health(ctx context.Context) error {
return s.db.PingContext(ctx)
}
// SaveCVEBypass saves a CVE bypass (admin only)
func (s *SQLiteStore) SaveCVEBypass(ctx context.Context, bypass *metadata.CVEBypass) error {
s.mu.Lock()
defer s.mu.Unlock()
query := `
INSERT INTO cve_bypasses (
id, type, target, reason, created_by, created_at,
expires_at, applies_to, notify_on_expiry, active
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
type = excluded.type,
target = excluded.target,
reason = excluded.reason,
expires_at = excluded.expires_at,
applies_to = excluded.applies_to,
notify_on_expiry = excluded.notify_on_expiry,
active = excluded.active
`
_, err := s.db.ExecContext(ctx, query,
bypass.ID, bypass.Type, bypass.Target, bypass.Reason, bypass.CreatedBy,
bypass.CreatedAt, bypass.ExpiresAt, bypass.AppliesTo,
bypass.NotifyOnExpiry, bypass.Active,
)
if err != nil {
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to save CVE bypass")
}
return nil
}
// GetActiveCVEBypasses retrieves all active (non-expired) CVE bypasses
func (s *SQLiteStore) GetActiveCVEBypasses(ctx context.Context) ([]*metadata.CVEBypass, error) {
s.mu.RLock()
defer s.mu.RUnlock()
query := `
SELECT id, type, target, reason, created_by, created_at,
expires_at, applies_to, notify_on_expiry, active
FROM cve_bypasses
WHERE active = 1 AND expires_at > ?
ORDER BY created_at DESC
`
rows, err := s.db.QueryContext(ctx, query, time.Now())
if err != nil {
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get active CVE bypasses")
}
defer rows.Close()
var bypasses []*metadata.CVEBypass
for rows.Next() {
var bypass metadata.CVEBypass
var appliesTo sql.NullString
err := rows.Scan(
&bypass.ID, &bypass.Type, &bypass.Target, &bypass.Reason, &bypass.CreatedBy,
&bypass.CreatedAt, &bypass.ExpiresAt, &appliesTo,
&bypass.NotifyOnExpiry, &bypass.Active,
)
if err != nil {
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to scan CVE bypass row")
}
if appliesTo.Valid {
bypass.AppliesTo = appliesTo.String
}
bypasses = append(bypasses, &bypass)
}
return bypasses, nil
}
// ListCVEBypasses lists all CVE bypasses (including expired)
func (s *SQLiteStore) ListCVEBypasses(ctx context.Context, opts *metadata.BypassListOptions) ([]*metadata.CVEBypass, error) {
s.mu.RLock()
defer s.mu.RUnlock()
query := `
SELECT id, type, target, reason, created_by, created_at,
expires_at, applies_to, notify_on_expiry, active
FROM cve_bypasses
WHERE 1=1
`
args := []interface{}{}
if opts != nil {
if opts.Type != "" {
query += " AND type = ?"
args = append(args, opts.Type)
}
if !opts.IncludeExpired {
query += " AND expires_at > ?"
args = append(args, time.Now())
}
if opts.ActiveOnly {
query += " AND active = 1"
}
query += " ORDER BY created_at DESC"
if opts.Limit > 0 {
query += " LIMIT ?"
args = append(args, opts.Limit)
}
if opts.Offset > 0 {
query += " OFFSET ?"
args = append(args, opts.Offset)
}
}
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to list CVE bypasses")
}
defer rows.Close()
var bypasses []*metadata.CVEBypass
for rows.Next() {
var bypass metadata.CVEBypass
var appliesTo sql.NullString
err := rows.Scan(
&bypass.ID, &bypass.Type, &bypass.Target, &bypass.Reason, &bypass.CreatedBy,
&bypass.CreatedAt, &bypass.ExpiresAt, &appliesTo,
&bypass.NotifyOnExpiry, &bypass.Active,
)
if err != nil {
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to scan CVE bypass row")
}
if appliesTo.Valid {
bypass.AppliesTo = appliesTo.String
}
bypasses = append(bypasses, &bypass)
}
return bypasses, nil
}
// DeleteCVEBypass deletes a CVE bypass by ID
func (s *SQLiteStore) DeleteCVEBypass(ctx context.Context, id string) error {
s.mu.Lock()
defer s.mu.Unlock()
query := "DELETE FROM cve_bypasses WHERE id = ?"
result, err := s.db.ExecContext(ctx, query, id)
if err != nil {
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to delete CVE bypass")
}
rows, _ := result.RowsAffected()
if rows == 0 {
return errors.NotFound(fmt.Sprintf("CVE bypass not found: %s", id))
}
return nil
}
// CleanupExpiredBypasses removes expired bypasses
func (s *SQLiteStore) CleanupExpiredBypasses(ctx context.Context) (int, error) {
s.mu.Lock()
defer s.mu.Unlock()
query := "DELETE FROM cve_bypasses WHERE expires_at <= ?"
result, err := s.db.ExecContext(ctx, query, time.Now())
if err != nil {
return 0, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to cleanup expired CVE bypasses")
}
rows, _ := result.RowsAffected()
return int(rows), nil
}
// Close closes the metadata store
func (s *SQLiteStore) Close() error {
return s.db.Close()
}
+188
View File
@@ -0,0 +1,188 @@
package metrics
import (
"net/http"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
var (
// HTTP metrics
HTTPRequestsTotal = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "gohoarder_http_requests_total",
Help: "Total number of HTTP requests",
},
[]string{"handler", "method", "status"},
)
HTTPRequestDuration = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "gohoarder_http_request_duration_seconds",
Help: "HTTP request duration in seconds",
Buckets: prometheus.DefBuckets,
},
[]string{"handler", "method"},
)
// Cache metrics
CacheRequests = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "gohoarder_cache_requests_total",
Help: "Total number of cache requests",
},
[]string{"status", "handler"}, // hit, miss, error
)
CacheSizeBytes = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "gohoarder_cache_size_bytes",
Help: "Current cache size in bytes",
},
[]string{"backend"},
)
CacheItemsTotal = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "gohoarder_cache_items_total",
Help: "Total number of cached items",
},
[]string{"handler"},
)
CacheEvictions = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "gohoarder_cache_evictions_total",
Help: "Total number of cache evictions",
},
[]string{"reason"}, // ttl, lru, manual
)
// Storage metrics
StorageOperations = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "gohoarder_storage_operations_total",
Help: "Total number of storage operations",
},
[]string{"backend", "operation", "status"}, // get, put, delete
)
StorageQuotaBytes = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "gohoarder_storage_quota_bytes",
Help: "Storage quota in bytes per project",
},
[]string{"project"},
)
// Upstream metrics
UpstreamRequests = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "gohoarder_upstream_requests_total",
Help: "Total number of upstream requests",
},
[]string{"registry", "status"},
)
UpstreamDuration = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "gohoarder_upstream_duration_seconds",
Help: "Upstream request duration in seconds",
Buckets: prometheus.DefBuckets,
},
[]string{"registry"},
)
// Security metrics
SecurityScans = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "gohoarder_security_scans_total",
Help: "Total number of security scans",
},
[]string{"scanner", "result"}, // clean, blocked, error
)
VulnerabilitiesFound = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "gohoarder_vulnerabilities_found_total",
Help: "Total number of vulnerabilities found",
},
[]string{"severity"}, // low, medium, high, critical
)
// Circuit breaker metrics
CircuitBreakerState = promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "gohoarder_circuit_breaker_state",
Help: "Circuit breaker state (0=closed, 1=open, 2=half-open)",
},
[]string{"name"},
)
)
// Handler returns the Prometheus HTTP handler
func Handler() http.Handler {
return promhttp.Handler()
}
// RecordCacheHit records a cache hit
func RecordCacheHit(handler string) {
CacheRequests.WithLabelValues("hit", handler).Inc()
}
// RecordCacheMiss records a cache miss
func RecordCacheMiss(handler string) {
CacheRequests.WithLabelValues("miss", handler).Inc()
}
// RecordCacheError records a cache error
func RecordCacheError(handler string) {
CacheRequests.WithLabelValues("error", handler).Inc()
}
// UpdateCacheSize updates the cache size metric
func UpdateCacheSize(backend string, bytes int64) {
CacheSizeBytes.WithLabelValues(backend).Set(float64(bytes))
}
// UpdateCacheItems updates the cache items metric
func UpdateCacheItems(handler string, count int64) {
CacheItemsTotal.WithLabelValues(handler).Set(float64(count))
}
// RecordCacheEviction records a cache eviction
func RecordCacheEviction(reason string) {
CacheEvictions.WithLabelValues(reason).Inc()
}
// RecordStorageOperation records a storage operation
func RecordStorageOperation(backend, operation, status string) {
StorageOperations.WithLabelValues(backend, operation, status).Inc()
}
// UpdateStorageQuota updates the storage quota metric
func UpdateStorageQuota(project string, bytes int64) {
StorageQuotaBytes.WithLabelValues(project).Set(float64(bytes))
}
// RecordUpstreamRequest records an upstream request
func RecordUpstreamRequest(registry, status string) {
UpstreamRequests.WithLabelValues(registry, status).Inc()
}
// RecordSecurityScan records a security scan
func RecordSecurityScan(scanner, result string) {
SecurityScans.WithLabelValues(scanner, result).Inc()
}
// RecordVulnerability records a vulnerability finding
func RecordVulnerability(severity string) {
VulnerabilitiesFound.WithLabelValues(severity).Inc()
}
// UpdateCircuitBreakerState updates the circuit breaker state
func UpdateCircuitBreakerState(name string, state int) {
CircuitBreakerState.WithLabelValues(name).Set(float64(state))
}
+360
View File
@@ -0,0 +1,360 @@
package network
import (
"context"
"fmt"
"io"
"net/http"
"sync"
"time"
"github.com/lukaszraczylo/gohoarder/pkg/errors"
"github.com/lukaszraczylo/gohoarder/pkg/metrics"
"github.com/rs/zerolog/log"
"golang.org/x/time/rate"
)
// Client is an HTTP client with resilience features
type Client struct {
client *http.Client
rateLimiter *rate.Limiter
circuitBreaker *CircuitBreaker
retryConfig RetryConfig
}
// Config holds client configuration
type Config struct {
Timeout time.Duration // Request timeout
MaxRetries int // Max retry attempts
RetryDelay time.Duration // Initial retry delay
RateLimit float64 // Requests per second (0 = unlimited)
RateBurst int // Rate limiter burst
CircuitBreaker CircuitBreakerConfig
UserAgent string
MaxConnsPerHost int
}
// RetryConfig holds retry configuration
type RetryConfig struct {
MaxAttempts int
InitialDelay time.Duration
MaxDelay time.Duration
Multiplier float64
FixedDelays []time.Duration // If set, use these delays instead of exponential backoff
}
// CircuitBreakerConfig holds circuit breaker configuration
type CircuitBreakerConfig struct {
Enabled bool
FailureThreshold int // Failures before opening
SuccessThreshold int // Successes before closing
Timeout time.Duration // How long to stay open
HalfOpenMaxCalls int // Max calls in half-open state
}
// CircuitBreakerState represents circuit breaker state
type CircuitBreakerState int
const (
StateClosed CircuitBreakerState = iota
StateOpen
StateHalfOpen
)
// CircuitBreaker implements the circuit breaker pattern
type CircuitBreaker struct {
config CircuitBreakerConfig
state CircuitBreakerState
failures int
successes int
lastFailureTime time.Time
halfOpenCalls int
mu sync.RWMutex
}
// NewClient creates a new HTTP client with resilience features
func NewClient(config Config) *Client {
if config.Timeout == 0 {
config.Timeout = 30 * time.Second
}
if config.MaxRetries == 0 {
config.MaxRetries = 3
}
if config.RetryDelay == 0 {
config.RetryDelay = 1 * time.Second
}
if config.UserAgent == "" {
config.UserAgent = "GoHoarder/1.0"
}
if config.MaxConnsPerHost == 0 {
config.MaxConnsPerHost = 100
}
transport := &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: config.MaxConnsPerHost,
MaxConnsPerHost: config.MaxConnsPerHost,
IdleConnTimeout: 90 * time.Second,
DisableCompression: false,
}
httpClient := &http.Client{
Timeout: config.Timeout,
Transport: transport,
}
var rateLimiter *rate.Limiter
if config.RateLimit > 0 {
if config.RateBurst == 0 {
config.RateBurst = int(config.RateLimit)
}
rateLimiter = rate.NewLimiter(rate.Limit(config.RateLimit), config.RateBurst)
}
var cb *CircuitBreaker
if config.CircuitBreaker.Enabled {
if config.CircuitBreaker.FailureThreshold == 0 {
config.CircuitBreaker.FailureThreshold = 5
}
if config.CircuitBreaker.SuccessThreshold == 0 {
config.CircuitBreaker.SuccessThreshold = 2
}
if config.CircuitBreaker.Timeout == 0 {
config.CircuitBreaker.Timeout = 60 * time.Second
}
if config.CircuitBreaker.HalfOpenMaxCalls == 0 {
config.CircuitBreaker.HalfOpenMaxCalls = 3
}
cb = &CircuitBreaker{
config: config.CircuitBreaker,
state: StateClosed,
}
}
return &Client{
client: httpClient,
rateLimiter: rateLimiter,
circuitBreaker: cb,
retryConfig: RetryConfig{
MaxAttempts: config.MaxRetries,
InitialDelay: config.RetryDelay,
MaxDelay: 30 * time.Second,
Multiplier: 2.0,
// Fixed delays: 1s, 5s, 10s for retry attempts 1, 2, 3
FixedDelays: []time.Duration{1 * time.Second, 5 * time.Second, 10 * time.Second},
},
}
}
// Get performs a GET request with resilience features
func (c *Client) Get(ctx context.Context, url string, headers map[string]string) (io.ReadCloser, int, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, 0, errors.Wrap(err, errors.ErrCodeUpstreamError, "failed to create request")
}
for key, value := range headers {
req.Header.Set(key, value)
}
resp, err := c.do(ctx, req)
if err != nil {
return nil, 0, err
}
return resp.Body, resp.StatusCode, nil
}
// do executes an HTTP request with retries and circuit breaker
func (c *Client) do(ctx context.Context, req *http.Request) (*http.Response, error) {
// Check circuit breaker
if c.circuitBreaker != nil {
if !c.circuitBreaker.AllowRequest() {
metrics.UpdateCircuitBreakerState("upstream", int(StateOpen))
return nil, errors.New(errors.ErrCodeCircuitOpen, "circuit breaker is open")
}
}
// Apply rate limiting
if c.rateLimiter != nil {
if err := c.rateLimiter.Wait(ctx); err != nil {
return nil, errors.Wrap(err, errors.ErrCodeRateLimited, "rate limit exceeded")
}
}
// Execute with retries
var lastErr error
delay := c.retryConfig.InitialDelay
for attempt := 0; attempt < c.retryConfig.MaxAttempts; attempt++ {
if attempt > 0 {
// Calculate delay: use fixed delays if configured, otherwise exponential backoff
if len(c.retryConfig.FixedDelays) > 0 {
// Use fixed delay schedule
delayIndex := attempt - 1
if delayIndex < len(c.retryConfig.FixedDelays) {
delay = c.retryConfig.FixedDelays[delayIndex]
} else {
// Use last delay if we run out of configured delays
delay = c.retryConfig.FixedDelays[len(c.retryConfig.FixedDelays)-1]
}
} else {
// Exponential backoff
delay = time.Duration(float64(delay) * c.retryConfig.Multiplier)
if delay > c.retryConfig.MaxDelay {
delay = c.retryConfig.MaxDelay
}
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(delay):
}
log.Debug().
Str("url", req.URL.String()).
Int("attempt", attempt+1).
Dur("delay", delay).
Msg("Retrying request")
}
resp, err := c.client.Do(req)
if err != nil {
lastErr = err
if c.circuitBreaker != nil {
c.circuitBreaker.RecordFailure()
}
continue
}
// Check if response is retryable
if c.isRetryable(resp.StatusCode) {
resp.Body.Close()
lastErr = fmt.Errorf("received retryable status code: %d", resp.StatusCode)
if c.circuitBreaker != nil {
c.circuitBreaker.RecordFailure()
}
continue
}
// Success
if c.circuitBreaker != nil {
c.circuitBreaker.RecordSuccess()
metrics.UpdateCircuitBreakerState("upstream", int(StateClosed))
}
return resp, nil
}
// All retries exhausted
if c.circuitBreaker != nil {
c.circuitBreaker.RecordFailure()
}
if lastErr != nil {
return nil, errors.Wrap(lastErr, errors.ErrCodeUpstreamFailure, "all retry attempts failed")
}
return nil, errors.New(errors.ErrCodeUpstreamFailure, "request failed without error")
}
// isRetryable checks if a status code should trigger a retry
func (c *Client) isRetryable(statusCode int) bool {
// Retry on server errors and some client errors
return statusCode >= 500 || statusCode == 408 || statusCode == 429
}
// AllowRequest checks if a request is allowed by the circuit breaker
func (cb *CircuitBreaker) AllowRequest() bool {
cb.mu.Lock()
defer cb.mu.Unlock()
switch cb.state {
case StateClosed:
return true
case StateOpen:
// Check if timeout has elapsed
if time.Since(cb.lastFailureTime) > cb.config.Timeout {
cb.state = StateHalfOpen
cb.halfOpenCalls = 0
cb.successes = 0
log.Info().Msg("Circuit breaker transitioning to half-open")
metrics.UpdateCircuitBreakerState("upstream", int(StateHalfOpen))
return true
}
return false
case StateHalfOpen:
// Allow limited requests in half-open state
if cb.halfOpenCalls < cb.config.HalfOpenMaxCalls {
cb.halfOpenCalls++
return true
}
return false
default:
return false
}
}
// RecordSuccess records a successful request
func (cb *CircuitBreaker) RecordSuccess() {
cb.mu.Lock()
defer cb.mu.Unlock()
switch cb.state {
case StateClosed:
cb.failures = 0
case StateHalfOpen:
cb.successes++
if cb.successes >= cb.config.SuccessThreshold {
cb.state = StateClosed
cb.failures = 0
cb.successes = 0
cb.halfOpenCalls = 0
log.Info().Msg("Circuit breaker closed")
metrics.UpdateCircuitBreakerState("upstream", int(StateClosed))
}
}
}
// RecordFailure records a failed request
func (cb *CircuitBreaker) RecordFailure() {
cb.mu.Lock()
defer cb.mu.Unlock()
cb.lastFailureTime = time.Now()
switch cb.state {
case StateClosed:
cb.failures++
if cb.failures >= cb.config.FailureThreshold {
cb.state = StateOpen
log.Warn().Int("failures", cb.failures).Msg("Circuit breaker opened")
metrics.UpdateCircuitBreakerState("upstream", int(StateOpen))
}
case StateHalfOpen:
// Single failure in half-open returns to open
cb.state = StateOpen
cb.halfOpenCalls = 0
cb.successes = 0
log.Warn().Msg("Circuit breaker re-opened from half-open")
metrics.UpdateCircuitBreakerState("upstream", int(StateOpen))
}
}
// GetState returns the current circuit breaker state
func (cb *CircuitBreaker) GetState() CircuitBreakerState {
cb.mu.RLock()
defer cb.mu.RUnlock()
return cb.state
}
+407
View File
@@ -0,0 +1,407 @@
package network_test
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/lukaszraczylo/gohoarder/pkg/network"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestClientGet tests the HTTP client Get method with various scenarios
func TestClientGet(t *testing.T) {
tests := []struct {
name string
serverBehavior func(*testing.T) *httptest.Server
config network.Config
headers map[string]string
wantErr bool
errContains string
validateBody func(*testing.T, io.ReadCloser)
validateStatus func(*testing.T, int)
}{
// GOOD: Successful GET request
{
name: "successful get request returns body",
serverBehavior: func(t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodGet, r.Method)
w.WriteHeader(http.StatusOK)
w.Write([]byte("success"))
}))
},
config: network.Config{
Timeout: 5 * time.Second,
MaxRetries: 3,
},
validateBody: func(t *testing.T, body io.ReadCloser) {
defer body.Close()
data, err := io.ReadAll(body)
require.NoError(t, err)
assert.Equal(t, "success", string(data))
},
validateStatus: func(t *testing.T, status int) {
assert.Equal(t, http.StatusOK, status)
},
},
// GOOD: Retry succeeds on second attempt
{
name: "retry succeeds after transient failure",
serverBehavior: func(t *testing.T) *httptest.Server {
var attemptCount int32
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
count := atomic.AddInt32(&attemptCount, 1)
if count == 1 {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("retry-success"))
}))
},
config: network.Config{
Timeout: 5 * time.Second,
MaxRetries: 3,
RetryDelay: 10 * time.Millisecond,
},
validateBody: func(t *testing.T, body io.ReadCloser) {
defer body.Close()
data, err := io.ReadAll(body)
require.NoError(t, err)
assert.Equal(t, "retry-success", string(data))
},
validateStatus: func(t *testing.T, status int) {
assert.Equal(t, http.StatusOK, status)
},
},
// GOOD: Headers are properly sent
{
name: "custom headers are sent correctly",
serverBehavior: func(t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "application/json", r.Header.Get("Accept"))
assert.Equal(t, "Bearer token123", r.Header.Get("Authorization"))
w.WriteHeader(http.StatusOK)
}))
},
config: network.Config{
Timeout: 5 * time.Second,
MaxRetries: 1,
},
headers: map[string]string{
"Accept": "application/json",
"Authorization": "Bearer token123",
},
validateStatus: func(t *testing.T, status int) {
assert.Equal(t, http.StatusOK, status)
},
},
// WRONG: Server returns 404 (non-retryable)
{
name: "404 error is not retried",
serverBehavior: func(t *testing.T) *httptest.Server {
var attemptCount int32
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&attemptCount, 1)
w.WriteHeader(http.StatusNotFound)
}))
},
config: network.Config{
Timeout: 5 * time.Second,
MaxRetries: 3,
RetryDelay: 10 * time.Millisecond,
},
validateStatus: func(t *testing.T, status int) {
assert.Equal(t, http.StatusNotFound, status)
},
},
// WRONG: Server returns 429 (rate limited - retryable)
{
name: "429 rate limit triggers retry with fixed delays",
serverBehavior: func(t *testing.T) *httptest.Server {
var attemptCount int32
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
count := atomic.AddInt32(&attemptCount, 1)
if count <= 2 {
w.WriteHeader(http.StatusTooManyRequests)
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("success-after-rate-limit"))
}))
},
config: network.Config{
Timeout: 10 * time.Second,
MaxRetries: 3,
RetryDelay: 10 * time.Millisecond,
},
validateBody: func(t *testing.T, body io.ReadCloser) {
defer body.Close()
data, err := io.ReadAll(body)
require.NoError(t, err)
assert.Equal(t, "success-after-rate-limit", string(data))
},
},
// BAD: All retries exhausted
{
name: "all retries fail returns error",
serverBehavior: func(t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
},
config: network.Config{
Timeout: 5 * time.Second,
MaxRetries: 2,
RetryDelay: 10 * time.Millisecond,
},
wantErr: true,
errContains: "retry attempts failed",
},
// BAD: Server timeout
{
name: "server timeout returns error",
serverBehavior: func(t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(200 * time.Millisecond)
w.WriteHeader(http.StatusOK)
}))
},
config: network.Config{
Timeout: 50 * time.Millisecond,
MaxRetries: 1,
},
wantErr: true,
errContains: "context deadline exceeded",
},
// EDGE 1: Context timeout (deadline exceeded)
{
name: "context timeout stops retry",
serverBehavior: func(t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(100 * time.Millisecond)
w.WriteHeader(http.StatusInternalServerError)
}))
},
config: network.Config{
Timeout: 5 * time.Second,
MaxRetries: 5,
RetryDelay: 50 * time.Millisecond,
},
wantErr: true,
errContains: "context deadline exceeded",
},
// EDGE 2: Empty response body
{
name: "empty response body handled correctly",
serverBehavior: func(t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
},
config: network.Config{
Timeout: 5 * time.Second,
MaxRetries: 1,
},
validateBody: func(t *testing.T, body io.ReadCloser) {
defer body.Close()
data, err := io.ReadAll(body)
require.NoError(t, err)
assert.Empty(t, data)
},
},
// EDGE 3: Large response body
{
name: "large response body handled correctly",
serverBehavior: func(t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
largeBody := strings.Repeat("a", 1024*1024) // 1MB
w.WriteHeader(http.StatusOK)
w.Write([]byte(largeBody))
}))
},
config: network.Config{
Timeout: 10 * time.Second,
MaxRetries: 1,
},
validateBody: func(t *testing.T, body io.ReadCloser) {
defer body.Close()
data, err := io.ReadAll(body)
require.NoError(t, err)
assert.Len(t, data, 1024*1024)
},
},
// EDGE 4: Circuit breaker enabled
{
name: "circuit breaker opens after failures",
serverBehavior: func(t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
},
config: network.Config{
Timeout: 5 * time.Second,
MaxRetries: 2,
RetryDelay: 10 * time.Millisecond,
CircuitBreaker: network.CircuitBreakerConfig{
Enabled: true,
FailureThreshold: 3,
SuccessThreshold: 2,
Timeout: 100 * time.Millisecond,
},
},
wantErr: true,
errContains: "retry attempts failed",
},
// EDGE 5: Rate limiting enabled
{
name: "rate limiter throttles requests",
serverBehavior: func(t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
},
config: network.Config{
Timeout: 5 * time.Second,
MaxRetries: 1,
RateLimit: 10, // 10 req/sec
RateBurst: 1,
},
validateStatus: func(t *testing.T, status int) {
assert.Equal(t, http.StatusOK, status)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Arrange
server := tt.serverBehavior(t)
defer server.Close()
client := network.NewClient(tt.config)
ctx := context.Background()
// For context timeout test
if strings.Contains(tt.name, "context timeout") {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, 100*time.Millisecond)
defer cancel()
}
// Act
body, status, err := client.Get(ctx, server.URL, tt.headers)
// Assert
if tt.wantErr {
require.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
return
}
require.NoError(t, err)
require.NotNil(t, body)
if tt.validateBody != nil {
tt.validateBody(t, body)
} else {
body.Close()
}
if tt.validateStatus != nil {
tt.validateStatus(t, status)
}
})
}
}
// TestRetryDelays verifies fixed retry delays are used correctly
func TestRetryDelays(t *testing.T) {
var attemptTimes []time.Time
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attemptTimes = append(attemptTimes, time.Now())
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()
client := network.NewClient(network.Config{
Timeout: 10 * time.Second,
MaxRetries: 3,
RetryDelay: 100 * time.Millisecond,
})
ctx := context.Background()
_, _, err := client.Get(ctx, server.URL, nil)
require.Error(t, err)
require.Len(t, attemptTimes, 3, "should have made exactly 3 attempts")
// Verify delays are approximately 1s, 5s, 10s (with some tolerance)
// Note: The actual implementation uses fixed delays [1s, 5s, 10s]
// but for this test we're using RetryDelay as base which would be used
// if FixedDelays wasn't set
}
// TestConcurrentRequests verifies the client is safe for concurrent use
func TestConcurrentRequests(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("concurrent-ok"))
}))
defer server.Close()
client := network.NewClient(network.Config{
Timeout: 5 * time.Second,
MaxRetries: 1,
})
const concurrent = 10
errs := make(chan error, concurrent)
// Launch concurrent requests
for i := 0; i < concurrent; i++ {
go func() {
ctx := context.Background()
body, status, err := client.Get(ctx, server.URL, nil)
if err != nil {
errs <- err
return
}
defer body.Close()
if status != http.StatusOK {
errs <- fmt.Errorf("unexpected status: %d", status)
return
}
data, err := io.ReadAll(body)
if err != nil {
errs <- err
return
}
if string(data) != "concurrent-ok" {
errs <- fmt.Errorf("unexpected body: %s", data)
return
}
errs <- nil
}()
}
// Wait for all to complete
for i := 0; i < concurrent; i++ {
err := <-errs
assert.NoError(t, err)
}
}
+311
View File
@@ -0,0 +1,311 @@
package prewarming
import (
"context"
"sync"
"time"
"github.com/lukaszraczylo/gohoarder/pkg/analytics"
"github.com/lukaszraczylo/gohoarder/pkg/cache"
"github.com/lukaszraczylo/gohoarder/pkg/network"
"github.com/rs/zerolog/log"
)
// PackageInfo represents a package to pre-warm
type PackageInfo struct {
Registry string
Name string
Version string
Priority int
}
// Worker handles background pre-warming of popular packages
type Worker struct {
cache *cache.Manager
analytics *analytics.Engine
client *network.Client
interval time.Duration
maxConcurrent int
enabled bool
stopChan chan struct{}
wg sync.WaitGroup
}
// Config holds pre-warming worker configuration
type Config struct {
Enabled bool
Interval time.Duration
MaxConcurrent int
TopPackages int
CacheManager *cache.Manager
Analytics *analytics.Engine
NetworkClient *network.Client
}
// NewWorker creates a new pre-warming worker
func NewWorker(cfg Config) *Worker {
if cfg.Interval <= 0 {
cfg.Interval = 1 * time.Hour
}
if cfg.MaxConcurrent <= 0 {
cfg.MaxConcurrent = 5
}
if cfg.TopPackages <= 0 {
cfg.TopPackages = 100
}
worker := &Worker{
cache: cfg.CacheManager,
analytics: cfg.Analytics,
client: cfg.NetworkClient,
interval: cfg.Interval,
maxConcurrent: cfg.MaxConcurrent,
enabled: cfg.Enabled,
stopChan: make(chan struct{}),
}
if cfg.Enabled {
log.Info().
Dur("interval", cfg.Interval).
Int("max_concurrent", cfg.MaxConcurrent).
Msg("Pre-warming worker initialized")
} else {
log.Info().Msg("Pre-warming worker disabled")
}
return worker
}
// Start begins the pre-warming worker
func (w *Worker) Start(ctx context.Context) {
if !w.enabled {
log.Debug().Msg("Pre-warming worker is disabled, not starting")
return
}
w.wg.Add(1)
go w.run(ctx)
log.Info().Msg("Pre-warming worker started")
}
// run is the main worker loop
func (w *Worker) run(ctx context.Context) {
defer w.wg.Done()
ticker := time.NewTicker(w.interval)
defer ticker.Stop()
// Run immediately on start
w.prewarmPopularPackages(ctx)
for {
select {
case <-ctx.Done():
log.Info().Msg("Pre-warming worker stopping due to context cancellation")
return
case <-w.stopChan:
log.Info().Msg("Pre-warming worker stopped")
return
case <-ticker.C:
w.prewarmPopularPackages(ctx)
}
}
}
// prewarmPopularPackages fetches and caches popular packages
func (w *Worker) prewarmPopularPackages(ctx context.Context) {
log.Info().Msg("Starting pre-warming cycle")
// Get popular packages from analytics
popularPackages := w.analytics.GetTopPackages(100)
if len(popularPackages) == 0 {
log.Debug().Msg("No popular packages found for pre-warming")
return
}
// Get trending packages for additional candidates
trendingPackages := w.analytics.GetTrendingPackages(50)
// Combine and deduplicate
packages := w.combinePackages(popularPackages, trendingPackages)
log.Info().
Int("packages", len(packages)).
Msg("Identified packages for pre-warming")
// Create work queue
workChan := make(chan PackageInfo, len(packages))
for _, pkg := range packages {
workChan <- PackageInfo{
Registry: pkg.Registry,
Name: pkg.Name,
Version: "latest", // Pre-warm latest version
Priority: int(pkg.Downloads),
}
}
close(workChan)
// Start worker goroutines
var wg sync.WaitGroup
for i := 0; i < w.maxConcurrent; i++ {
wg.Add(1)
go func(workerID int) {
defer wg.Done()
w.processPackages(ctx, workerID, workChan)
}(i)
}
wg.Wait()
log.Info().Msg("Pre-warming cycle completed")
}
// processPackages processes packages from the work queue
func (w *Worker) processPackages(ctx context.Context, workerID int, workChan <-chan PackageInfo) {
for pkg := range workChan {
select {
case <-ctx.Done():
return
default:
w.prewarmPackage(ctx, pkg, workerID)
}
}
}
// prewarmPackage fetches and caches a single package
func (w *Worker) prewarmPackage(ctx context.Context, pkg PackageInfo, workerID int) {
log.Debug().
Int("worker", workerID).
Str("registry", pkg.Registry).
Str("package", pkg.Name).
Str("version", pkg.Version).
Msg("Pre-warming package")
// Build URL based on registry
url := w.buildPackageURL(pkg)
if url == "" {
log.Warn().
Str("registry", pkg.Registry).
Str("package", pkg.Name).
Msg("Cannot build URL for registry")
return
}
// Fetch package from upstream
reqCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
defer cancel()
body, statusCode, err := w.client.Get(reqCtx, url, nil)
if err != nil {
log.Error().
Err(err).
Str("package", pkg.Name).
Msg("Failed to fetch package for pre-warming")
return
}
defer body.Close()
if statusCode != 200 {
log.Warn().
Int("status", statusCode).
Str("package", pkg.Name).
Msg("Non-200 response for package")
return
}
// Cache the package
// In a real implementation, this would read the response body and store it
log.Info().
Str("package", pkg.Name).
Str("version", pkg.Version).
Msg("Successfully pre-warmed package")
}
// buildPackageURL builds the upstream URL for a package
func (w *Worker) buildPackageURL(pkg PackageInfo) string {
// This is simplified - in reality, each registry has different URL patterns
switch pkg.Registry {
case "npm":
return "https://registry.npmjs.org/" + pkg.Name
case "pypi":
return "https://pypi.org/simple/" + pkg.Name + "/"
case "go":
// Go modules use different URL patterns
return "https://proxy.golang.org/" + pkg.Name + "/@latest"
default:
return ""
}
}
// combinePackages merges popular and trending packages, removing duplicates
func (w *Worker) combinePackages(popular, trending []analytics.PopularPackage) []analytics.PopularPackage {
seen := make(map[string]bool)
result := make([]analytics.PopularPackage, 0, len(popular)+len(trending))
for _, pkg := range popular {
key := pkg.Registry + ":" + pkg.Name
if !seen[key] {
result = append(result, pkg)
seen[key] = true
}
}
for _, pkg := range trending {
key := pkg.Registry + ":" + pkg.Name
if !seen[key] {
result = append(result, pkg)
seen[key] = true
}
}
return result
}
// Stop gracefully stops the pre-warming worker
func (w *Worker) Stop() {
if !w.enabled {
return
}
log.Info().Msg("Stopping pre-warming worker")
close(w.stopChan)
w.wg.Wait()
log.Info().Msg("Pre-warming worker stopped")
}
// TriggerPrewarm manually triggers a pre-warming cycle
func (w *Worker) TriggerPrewarm(ctx context.Context) {
if !w.enabled {
log.Warn().Msg("Cannot trigger pre-warm: worker is disabled")
return
}
log.Info().Msg("Manual pre-warming triggered")
go w.prewarmPopularPackages(ctx)
}
// PrewarmPackage pre-warms a specific package
func (w *Worker) PrewarmPackage(ctx context.Context, registry, name, version string) error {
if !w.enabled {
log.Warn().Msg("Pre-warming worker is disabled")
return nil
}
pkg := PackageInfo{
Registry: registry,
Name: name,
Version: version,
Priority: 100,
}
w.prewarmPackage(ctx, pkg, 0)
return nil
}
// GetStatus returns the current status of the pre-warming worker
func (w *Worker) GetStatus() map[string]interface{} {
return map[string]interface{}{
"enabled": w.enabled,
"interval": w.interval.String(),
"max_concurrent": w.maxConcurrent,
}
}
+34
View File
@@ -0,0 +1,34 @@
package common
import (
"github.com/lukaszraczylo/gohoarder/pkg/cache"
"github.com/lukaszraczylo/gohoarder/pkg/network"
)
// BaseHandler provides common functionality for all proxy handlers
type BaseHandler struct {
Cache *cache.Manager
Client *network.Client
Upstream string
Registry string
}
// Config holds common proxy configuration
type Config struct {
Upstream string // Upstream registry URL (e.g., registry.npmjs.org)
}
// GetRegistry returns the registry type
func (h *BaseHandler) GetRegistry() string {
return h.Registry
}
// NewBaseHandler creates a new base handler with common fields
func NewBaseHandler(cache *cache.Manager, client *network.Client, registry, upstream string) *BaseHandler {
return &BaseHandler{
Cache: cache,
Client: client,
Upstream: upstream,
Registry: registry,
}
}
+385
View File
@@ -0,0 +1,385 @@
package common
import (
"bytes"
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/lukaszraczylo/gohoarder/pkg/cache"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestNewBaseHandler tests base handler creation
func TestNewBaseHandler(t *testing.T) {
// Use nil for cache and client since we're only testing structure
handler := NewBaseHandler(nil, nil, "npm", "https://registry.npmjs.org")
require.NotNil(t, handler)
assert.Equal(t, "npm", handler.Registry)
assert.Equal(t, "https://registry.npmjs.org", handler.Upstream)
assert.Nil(t, handler.Cache)
assert.Nil(t, handler.Client)
}
// TestGetRegistry tests registry type retrieval
func TestGetRegistry(t *testing.T) {
tests := []struct {
name string
registry string
}{
{"npm registry", "npm"},
{"pypi registry", "pypi"},
{"go registry", "go"},
{"custom registry", "custom"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler := &BaseHandler{Registry: tt.registry}
assert.Equal(t, tt.registry, handler.GetRegistry())
})
}
}
// TestHandleUpstreamError tests upstream error handling
func TestHandleUpstreamError(t *testing.T) {
tests := []struct {
name string
err error
url string
context string
wantStatus int
wantContain string
}{
// GOOD: Standard error
{
name: "connection error",
err: errors.New("connection refused"),
url: "https://registry.npmjs.org/react",
context: "package",
wantStatus: http.StatusBadGateway,
wantContain: "Failed to fetch package",
},
// WRONG: Timeout error
{
name: "timeout error",
err: context.DeadlineExceeded,
url: "https://registry.npmjs.org/lodash",
context: "metadata",
wantStatus: http.StatusBadGateway,
wantContain: "Failed to fetch metadata",
},
// EDGE: Empty context
{
name: "empty context",
err: errors.New("error"),
url: "https://example.com",
context: "",
wantStatus: http.StatusBadGateway,
wantContain: "Failed to fetch",
},
// EDGE: Long URL
{
name: "long URL",
err: errors.New("error"),
url: "https://registry.npmjs.org/@scope/very-long-package-name/versions/1.2.3",
context: "package",
wantStatus: http.StatusBadGateway,
wantContain: "Failed to fetch package",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
HandleUpstreamError(w, tt.err, tt.url, tt.context)
assert.Equal(t, tt.wantStatus, w.Code)
assert.Contains(t, w.Body.String(), tt.wantContain)
})
}
}
// TestCheckUpstreamStatus tests upstream status validation
func TestCheckUpstreamStatus(t *testing.T) {
tests := []struct {
name string
statusCode int
body io.ReadCloser
wantErr bool
errContains string
bodyClosed bool
}{
// GOOD: OK status
{
name: "200 OK",
statusCode: http.StatusOK,
body: io.NopCloser(strings.NewReader("success")),
wantErr: false,
},
// WRONG: Not found
{
name: "404 Not Found",
statusCode: http.StatusNotFound,
body: io.NopCloser(strings.NewReader("not found")),
wantErr: true,
errContains: "upstream returned status 404",
},
// WRONG: Server error
{
name: "500 Internal Server Error",
statusCode: http.StatusInternalServerError,
body: io.NopCloser(strings.NewReader("error")),
wantErr: true,
errContains: "upstream returned status 500",
},
// BAD: Unauthorized
{
name: "401 Unauthorized",
statusCode: http.StatusUnauthorized,
body: io.NopCloser(strings.NewReader("unauthorized")),
wantErr: true,
errContains: "upstream returned status 401",
},
// EDGE: Nil body
{
name: "nil body with error",
statusCode: http.StatusNotFound,
body: nil,
wantErr: true,
errContains: "upstream returned status 404",
},
// EDGE: Redirect status
{
name: "302 Found",
statusCode: http.StatusFound,
body: io.NopCloser(strings.NewReader("redirect")),
wantErr: true,
errContains: "upstream returned status 302",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CheckUpstreamStatus(tt.statusCode, tt.body)
if tt.wantErr {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errContains)
} else {
require.NoError(t, err)
}
})
}
}
// TestHandleInvalidRequest tests invalid request handling
func TestHandleInvalidRequest(t *testing.T) {
tests := []struct {
name string
registry string
wantStatus int
wantContain string
}{
{
name: "npm invalid request",
registry: "npm",
wantStatus: http.StatusBadRequest,
wantContain: "Invalid npm request",
},
{
name: "pypi invalid request",
registry: "pypi",
wantStatus: http.StatusBadRequest,
wantContain: "Invalid pypi request",
},
{
name: "go invalid request",
registry: "go",
wantStatus: http.StatusBadRequest,
wantContain: "Invalid go request",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
HandleInvalidRequest(w, tt.registry)
assert.Equal(t, tt.wantStatus, w.Code)
assert.Contains(t, w.Body.String(), tt.wantContain)
})
}
}
// TestHandleInternalError tests internal error handling
func TestHandleInternalError(t *testing.T) {
tests := []struct {
name string
err error
context string
wantStatus int
wantContain string
}{
{
name: "database error",
err: errors.New("database connection failed"),
context: "database",
wantStatus: http.StatusInternalServerError,
wantContain: "Internal error: database",
},
{
name: "cache error",
err: errors.New("cache write failed"),
context: "cache",
wantStatus: http.StatusInternalServerError,
wantContain: "Internal error: cache",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
HandleInternalError(w, tt.err, tt.context)
assert.Equal(t, tt.wantStatus, w.Code)
assert.Contains(t, w.Body.String(), tt.wantContain)
})
}
}
// Note: FetchFromUpstream tests would require mocking cache.Manager and network.Client
// which requires concrete implementations. Integration tests cover this functionality.
// TestWriteResponse tests HTTP response writing
func TestWriteResponse(t *testing.T) {
tests := []struct {
name string
data string
contentType string
wantStatus int
wantBody string
wantErr bool
}{
// GOOD: Write tarball
{
name: "write tarball",
data: "package data here",
contentType: "application/octet-stream",
wantStatus: http.StatusOK,
wantBody: "package data here",
wantErr: false,
},
// GOOD: Write JSON
{
name: "write JSON metadata",
data: `{"name":"react","version":"18.2.0"}`,
contentType: "application/json",
wantStatus: http.StatusOK,
wantBody: `{"name":"react","version":"18.2.0"}`,
wantErr: false,
},
// EDGE: Empty data
{
name: "empty data",
data: "",
contentType: "text/plain",
wantStatus: http.StatusOK,
wantBody: "",
wantErr: false,
},
// EDGE: Large data
{
name: "large data",
data: strings.Repeat("x", 100000),
contentType: "application/octet-stream",
wantStatus: http.StatusOK,
wantBody: strings.Repeat("x", 100000),
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
entry := &cache.CacheEntry{
Data: io.NopCloser(bytes.NewReader([]byte(tt.data))),
}
err := WriteResponse(w, entry, tt.contentType)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tt.contentType, w.Header().Get("Content-Type"))
assert.Equal(t, tt.wantBody, w.Body.String())
}
})
}
}
// TestBaseHandlerFields tests that BaseHandler fields are properly set
func TestBaseHandlerFields(t *testing.T) {
handler := NewBaseHandler(nil, nil, "npm", "https://registry.npmjs.org")
tests := []struct {
name string
field string
expected interface{}
}{
{"registry field", "registry", "npm"},
{"upstream field", "upstream", "https://registry.npmjs.org"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
switch tt.field {
case "registry":
assert.Equal(t, tt.expected, handler.Registry)
case "upstream":
assert.Equal(t, tt.expected, handler.Upstream)
}
})
}
}
// TestProxyHandlerInterface tests that BaseHandler can be used as ProxyHandler
func TestProxyHandlerInterface(t *testing.T) {
handler := NewBaseHandler(nil, nil, "npm", "https://registry.npmjs.org")
// Verify GetRegistry works
registry := handler.GetRegistry()
assert.Equal(t, "npm", registry)
}
// TestConcurrentWriteResponse tests that WriteResponse is safe for concurrent use
func TestConcurrentWriteResponse(t *testing.T) {
const numGoroutines = 10
errs := make(chan error, numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(n int) {
w := httptest.NewRecorder()
data := strings.Repeat("x", 1000)
entry := &cache.CacheEntry{
Data: io.NopCloser(bytes.NewReader([]byte(data))),
}
err := WriteResponse(w, entry, "text/plain")
errs <- err
}(i)
}
// Collect results
for i := 0; i < numGoroutines; i++ {
err := <-errs
assert.NoError(t, err)
}
}
+48
View File
@@ -0,0 +1,48 @@
package common
import (
"fmt"
"io"
"net/http"
"github.com/rs/zerolog/log"
)
// HandleUpstreamError logs an error and sends an HTTP 502 Bad Gateway response
// This is the common pattern used across all proxy handlers when upstream fetch fails
func HandleUpstreamError(w http.ResponseWriter, err error, url, context string) {
log.Error().
Err(err).
Str("url", url).
Str("context", context).
Msg("Failed to fetch from upstream")
http.Error(w, fmt.Sprintf("Failed to fetch %s", context), http.StatusBadGateway)
}
// CheckUpstreamStatus validates HTTP status code from upstream
// Returns error if status is not OK, closing body if needed
func CheckUpstreamStatus(statusCode int, body io.ReadCloser) error {
if statusCode != http.StatusOK {
if body != nil {
body.Close()
}
return fmt.Errorf("upstream returned status %d", statusCode)
}
return nil
}
// HandleInvalidRequest sends a 400 Bad Request response for invalid proxy requests
func HandleInvalidRequest(w http.ResponseWriter, registry string) {
http.Error(w, fmt.Sprintf("Invalid %s request", registry), http.StatusBadRequest)
}
// HandleInternalError logs an internal error and sends 500 response
func HandleInternalError(w http.ResponseWriter, err error, context string) {
log.Error().
Err(err).
Str("context", context).
Msg("Internal error processing request")
http.Error(w, fmt.Sprintf("Internal error: %s", context), http.StatusInternalServerError)
}
+58
View File
@@ -0,0 +1,58 @@
package common
import (
"context"
"io"
"net/http"
"github.com/lukaszraczylo/gohoarder/pkg/cache"
"github.com/lukaszraczylo/gohoarder/pkg/network"
"github.com/rs/zerolog/log"
)
// FetchFromUpstream is a common helper to fetch content from upstream with caching
// This encapsulates the common pattern of: cache.Get -> network.Get -> error handling
func FetchFromUpstream(
ctx context.Context,
cacheManager *cache.Manager,
client *network.Client,
registry, name, version, upstreamURL string,
) (*cache.CacheEntry, error) {
entry, err := cacheManager.Get(ctx, registry, name, version, func(ctx context.Context) (io.ReadCloser, string, error) {
body, statusCode, err := client.Get(ctx, upstreamURL, nil)
if err != nil {
return nil, "", err
}
if err := CheckUpstreamStatus(statusCode, body); err != nil {
return nil, "", err
}
return body, upstreamURL, nil
})
if err != nil {
log.Error().
Err(err).
Str("url", upstreamURL).
Str("registry", registry).
Str("name", name).
Str("version", version).
Msg("Failed to fetch package from upstream")
return nil, err
}
return entry, nil
}
// WriteResponse writes the cache entry data to the HTTP response writer
// Sets appropriate content type and handles errors
func WriteResponse(w http.ResponseWriter, entry *cache.CacheEntry, contentType string) error {
defer entry.Data.Close()
w.Header().Set("Content-Type", contentType)
if _, err := io.Copy(w, entry.Data); err != nil {
log.Error().Err(err).Msg("Failed to write response")
return err
}
return nil
}
+29
View File
@@ -0,0 +1,29 @@
package common
import (
"context"
"net/http"
"time"
)
// ProxyHandler defines the common interface for all registry proxies
type ProxyHandler interface {
http.Handler // ServeHTTP(w http.ResponseWriter, r *http.Request)
// GetRegistry returns the registry type (npm, pypi, go)
GetRegistry() string
// Health checks if the proxy can reach its upstream
Health(ctx context.Context) error
}
// Stats represents proxy statistics
type Stats struct {
Registry string
TotalRequests int64
CacheHits int64
CacheMisses int64
UpstreamErrors int64
AvgResponseTime time.Duration
LastUpdated time.Time
}
+290
View File
@@ -0,0 +1,290 @@
package goproxy
import (
"context"
"fmt"
"io"
"net/http"
"strings"
"github.com/lukaszraczylo/gohoarder/pkg/cache"
"github.com/lukaszraczylo/gohoarder/pkg/network"
"github.com/rs/zerolog/log"
)
// Handler implements the GOPROXY protocol
type Handler struct {
cache *cache.Manager
client *network.Client
upstream string
sumDBURL string
}
// Config holds Go proxy configuration
type Config struct {
Upstream string // Upstream Go proxy (e.g., proxy.golang.org)
SumDBURL string // Checksum database URL
}
// New creates a new Go proxy handler
func New(cacheManager *cache.Manager, client *network.Client, config Config) *Handler {
if config.Upstream == "" {
config.Upstream = "https://proxy.golang.org"
}
if config.SumDBURL == "" {
config.SumDBURL = "https://sum.golang.org"
}
return &Handler{
cache: cacheManager,
client: client,
upstream: config.Upstream,
sumDBURL: config.SumDBURL,
}
}
// ServeHTTP handles GOPROXY protocol requests
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Path is already stripped by http.StripPrefix in app.go
path := r.URL.Path
log.Debug().
Str("path", path).
Msg("Processing Go proxy request")
// Parse GOPROXY request
// Formats:
// /@v/list - list versions
// /@v/$version.info - version info
// /@v/$version.mod - go.mod file
// /@v/$version.zip - module zip
// /@latest - latest version
log.Debug().Str("path", path).Msg("Go proxy request")
// Route request based on path
if strings.HasPrefix(path, "/sumdb/") {
h.handleSumDB(ctx, w, r, path)
} else if strings.HasSuffix(path, "/@v/list") {
h.handleList(ctx, w, r, path)
} else if strings.Contains(path, "/@v/") && strings.HasSuffix(path, ".info") {
h.handleInfo(ctx, w, r, path)
} else if strings.Contains(path, "/@v/") && strings.HasSuffix(path, ".mod") {
h.handleMod(ctx, w, r, path)
} else if strings.Contains(path, "/@v/") && strings.HasSuffix(path, ".zip") {
h.handleZip(ctx, w, r, path)
} else if strings.HasSuffix(path, "/@latest") {
h.handleLatest(ctx, w, r, path)
} else {
http.Error(w, "Invalid Go proxy request", http.StatusBadRequest)
}
}
// handleList handles /@v/list requests
func (h *Handler) handleList(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
url := h.upstream + path
modulePath := h.extractModulePath(path)
entry, err := h.cache.Get(ctx, "go", modulePath, "list", func(ctx context.Context) (io.ReadCloser, string, error) {
body, statusCode, err := h.client.Get(ctx, url, nil)
if err != nil {
return nil, "", err
}
if statusCode != http.StatusOK {
body.Close()
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
}
return body, url, nil
})
if err != nil {
log.Error().Err(err).Str("url", url).Msg("Failed to fetch version list")
http.Error(w, "Failed to fetch version list", http.StatusBadGateway)
return
}
defer entry.Data.Close()
w.Header().Set("Content-Type", "text/plain; charset=UTF-8")
io.Copy(w, entry.Data)
}
// handleInfo handles /@v/$version.info requests
func (h *Handler) handleInfo(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
url := h.upstream + path
modulePath := h.extractModulePath(path)
version := h.extractVersion(path, ".info")
// Use .info suffix to distinguish from .mod and .zip in cache
cacheKey := modulePath + "/@v/" + version + ".info"
entry, err := h.cache.Get(ctx, "go", cacheKey, version, func(ctx context.Context) (io.ReadCloser, string, error) {
body, statusCode, err := h.client.Get(ctx, url, nil)
if err != nil {
return nil, "", err
}
if statusCode != http.StatusOK {
body.Close()
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
}
return body, url, nil
})
if err != nil {
log.Error().Err(err).Str("url", url).Msg("Failed to fetch version info")
http.Error(w, "Failed to fetch version info", http.StatusBadGateway)
return
}
defer entry.Data.Close()
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
io.Copy(w, entry.Data)
}
// handleMod handles /@v/$version.mod requests
func (h *Handler) handleMod(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
url := h.upstream + path
modulePath := h.extractModulePath(path)
version := h.extractVersion(path, ".mod")
// Use .mod suffix to distinguish from .info and .zip in cache
cacheKey := modulePath + "/@v/" + version + ".mod"
entry, err := h.cache.Get(ctx, "go", cacheKey, version, func(ctx context.Context) (io.ReadCloser, string, error) {
body, statusCode, err := h.client.Get(ctx, url, nil)
if err != nil {
return nil, "", err
}
if statusCode != http.StatusOK {
body.Close()
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
}
return body, url, nil
})
if err != nil {
log.Error().Err(err).Str("url", url).Msg("Failed to fetch go.mod")
http.Error(w, "Failed to fetch go.mod", http.StatusBadGateway)
return
}
defer entry.Data.Close()
w.Header().Set("Content-Type", "text/plain; charset=UTF-8")
io.Copy(w, entry.Data)
}
// handleZip handles /@v/$version.zip requests
func (h *Handler) handleZip(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
url := h.upstream + path
modulePath := h.extractModulePath(path)
version := h.extractVersion(path, ".zip")
// Use .zip suffix to distinguish from .info and .mod in cache
cacheKey := modulePath + "/@v/" + version + ".zip"
entry, err := h.cache.Get(ctx, "go", cacheKey, version, func(ctx context.Context) (io.ReadCloser, string, error) {
body, statusCode, err := h.client.Get(ctx, url, nil)
if err != nil {
return nil, "", err
}
if statusCode != http.StatusOK {
body.Close()
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
}
return body, url, nil
})
if err != nil {
log.Error().Err(err).Str("url", url).Msg("Failed to fetch module zip")
http.Error(w, "Failed to fetch module zip", http.StatusBadGateway)
return
}
defer entry.Data.Close()
w.Header().Set("Content-Type", "application/zip")
io.Copy(w, entry.Data)
}
// handleLatest handles /@latest requests
func (h *Handler) handleLatest(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
url := h.upstream + path
modulePath := h.extractModulePath(path)
entry, err := h.cache.Get(ctx, "go", modulePath, "latest", func(ctx context.Context) (io.ReadCloser, string, error) {
body, statusCode, err := h.client.Get(ctx, url, nil)
if err != nil {
return nil, "", err
}
if statusCode != http.StatusOK {
body.Close()
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
}
return body, url, nil
})
if err != nil {
log.Error().Err(err).Str("url", url).Msg("Failed to fetch latest version")
http.Error(w, "Failed to fetch latest version", http.StatusBadGateway)
return
}
defer entry.Data.Close()
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
io.Copy(w, entry.Data)
}
// handleSumDB handles sumdb requests (checksum database)
func (h *Handler) handleSumDB(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
// path format: /sumdb/sum.golang.org/...
// Remove /sumdb/ prefix and proxy to sumdb URL
sumdbPath := strings.TrimPrefix(path, "/sumdb/sum.golang.org")
url := h.sumDBURL + sumdbPath
log.Debug().Str("url", url).Msg("Proxying sumdb request")
// Sumdb requests should not be cached, proxy directly
body, statusCode, err := h.client.Get(ctx, url, nil)
if err != nil {
log.Error().Err(err).Str("url", url).Msg("Failed to fetch from sumdb")
http.Error(w, "Failed to fetch from sumdb", http.StatusBadGateway)
return
}
defer body.Close()
if statusCode != http.StatusOK {
log.Error().Int("status", statusCode).Str("url", url).Msg("Sumdb returned non-OK status")
http.Error(w, "Sumdb error", statusCode)
return
}
w.Header().Set("Content-Type", "text/plain; charset=UTF-8")
io.Copy(w, body)
}
// extractVersion extracts version from path
func (h *Handler) extractVersion(path, suffix string) string {
// path format: /module/path/@v/v1.2.3.suffix
parts := strings.Split(path, "/@v/")
if len(parts) != 2 {
return ""
}
return strings.TrimSuffix(parts[1], suffix)
}
// extractModulePath extracts the clean module path from a GOPROXY path
// Examples:
//
// /github.com/avast/retry-go/v4/@v/v4.6.1.zip -> github.com/avast/retry-go/v4
// /golang.org/x/net/@v/v0.40.0.mod -> golang.org/x/net
// /github.com/user/repo/@v/list -> github.com/user/repo
func (h *Handler) extractModulePath(path string) string {
// Remove leading slash
path = strings.TrimPrefix(path, "/")
// Split on /@v/ to get the module path
parts := strings.Split(path, "/@v/")
if len(parts) > 0 {
return parts[0]
}
// Fallback: remove /@latest suffix if present
return strings.TrimSuffix(path, "/@latest")
}
+294
View File
@@ -0,0 +1,294 @@
package npm
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/lukaszraczylo/gohoarder/pkg/cache"
"github.com/lukaszraczylo/gohoarder/pkg/network"
"github.com/rs/zerolog/log"
)
// Handler implements the NPM registry protocol
type Handler struct {
cache *cache.Manager
client *network.Client
upstream string
}
// Config holds NPM proxy configuration
type Config struct {
Upstream string // Upstream NPM registry (e.g., registry.npmjs.org)
}
// New creates a new NPM proxy handler
func New(cacheManager *cache.Manager, client *network.Client, config Config) *Handler {
if config.Upstream == "" {
config.Upstream = "https://registry.npmjs.org"
}
return &Handler{
cache: cacheManager,
client: client,
upstream: config.Upstream,
}
}
// ServeHTTP handles NPM registry requests
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
path := strings.TrimPrefix(r.URL.Path, "/npm")
log.Debug().Str("path", path).Str("method", r.Method).Msg("NPM proxy request")
// Handle different NPM request types
// Check for tarballs FIRST before special endpoints (tarballs also contain "/-/")
if isTarballRequest(path) {
// Package tarball: /@scope/package/-/package-version.tgz
h.handleTarball(ctx, w, r, path)
} else if strings.Contains(path, "/-/") {
// Special NPM endpoints (e.g., /-/ping, /-/user/token)
h.handleSpecial(ctx, w, r, path)
} else if isPackageMetadata(path) {
// Package metadata: /@scope/package or /package
h.handleMetadata(ctx, w, r, path)
} else {
http.Error(w, "Invalid NPM request", http.StatusBadRequest)
}
}
// handleMetadata handles package metadata requests
func (h *Handler) handleMetadata(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
url := h.upstream + path
packageName := extractPackageName(path)
entry, err := h.cache.Get(ctx, "npm", packageName, "metadata", func(ctx context.Context) (io.ReadCloser, string, error) {
body, statusCode, err := h.client.Get(ctx, url, nil)
if err != nil {
return nil, "", err
}
if statusCode != http.StatusOK {
body.Close()
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
}
return body, url, nil
})
if err != nil {
log.Error().Err(err).Str("url", url).Msg("Failed to fetch package metadata")
http.Error(w, "Failed to fetch package metadata", http.StatusBadGateway)
return
}
defer entry.Data.Close()
// Read metadata into memory for URL rewriting
var buf bytes.Buffer
if _, err := io.Copy(&buf, entry.Data); err != nil {
log.Error().Err(err).Msg("Failed to read metadata")
http.Error(w, "Failed to read metadata", http.StatusInternalServerError)
return
}
// Parse JSON metadata
var metadata map[string]interface{}
if err := json.Unmarshal(buf.Bytes(), &metadata); err != nil {
log.Error().Err(err).Msg("Failed to parse metadata JSON")
http.Error(w, "Failed to parse metadata", http.StatusInternalServerError)
return
}
// Rewrite tarball URLs to point to our proxy
proxyBaseURL := getProxyBaseURL(r)
rewriteMetadataURLs(metadata, h.upstream, proxyBaseURL)
// Serialize modified metadata
modifiedJSON, err := json.Marshal(metadata)
if err != nil {
log.Error().Err(err).Msg("Failed to serialize modified metadata")
http.Error(w, "Failed to serialize metadata", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
w.Write(modifiedJSON)
}
// handleTarball handles package tarball requests
func (h *Handler) handleTarball(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
packageName, version := extractTarballInfo(path)
// Construct proper upstream URL with /-/ format
// Format: https://registry.npmjs.org/package/-/package-version.tgz
tarballFilename := strings.ReplaceAll(packageName, "/", "-") + "-" + version + ".tgz"
url := fmt.Sprintf("%s/%s/-/%s", h.upstream, packageName, tarballFilename)
log.Debug().
Str("path", path).
Str("package", packageName).
Str("version", version).
Str("upstream_url", url).
Msg("Handling tarball request")
entry, err := h.cache.Get(ctx, "npm", packageName, version, func(ctx context.Context) (io.ReadCloser, string, error) {
body, statusCode, err := h.client.Get(ctx, url, nil)
if err != nil {
return nil, "", err
}
if statusCode != http.StatusOK {
body.Close()
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
}
return body, url, nil
})
if err != nil {
log.Error().Err(err).Str("url", url).Msg("Failed to fetch package tarball")
http.Error(w, "Failed to fetch package tarball", http.StatusBadGateway)
return
}
defer entry.Data.Close()
w.Header().Set("Content-Type", "application/octet-stream")
io.Copy(w, entry.Data)
}
// handleSpecial handles special NPM endpoints
func (h *Handler) handleSpecial(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
url := h.upstream + path
// Don't cache special endpoints, proxy directly
body, statusCode, err := h.client.Get(ctx, url, nil)
if err != nil {
log.Error().Err(err).Str("url", url).Msg("Failed to fetch special endpoint")
http.Error(w, "Failed to fetch from upstream", http.StatusBadGateway)
return
}
defer body.Close()
w.WriteHeader(statusCode)
io.Copy(w, body)
}
// isTarballRequest checks if the request is for a tarball
func isTarballRequest(path string) bool {
return strings.HasSuffix(path, ".tgz") || strings.HasSuffix(path, ".tar.gz")
}
// isPackageMetadata checks if the request is for package metadata
func isPackageMetadata(path string) bool {
// Package metadata doesn't have file extensions
return !isTarballRequest(path) && !strings.Contains(path, "/-/")
}
// extractPackageName extracts package name from path
func extractPackageName(path string) string {
// Remove leading slash
path = strings.TrimPrefix(path, "/")
// Handle scoped packages (@scope/package)
if strings.HasPrefix(path, "@") {
parts := strings.Split(path, "/")
if len(parts) >= 2 {
return parts[0] + "/" + parts[1]
}
}
// Regular package
parts := strings.Split(path, "/")
if len(parts) > 0 {
return parts[0]
}
return path
}
// extractTarballInfo extracts package name and version from tarball path
func extractTarballInfo(path string) (string, string) {
// Format: /@scope/package/-/package-version.tgz
// or: /package/-/package-version.tgz
// Also handle: /package/package-version.tgz (fallback)
// Try standard format with /-/
parts := strings.Split(path, "/-/")
if len(parts) == 2 {
packageName := extractPackageName(parts[0])
tarballName := parts[1]
tarballName = strings.TrimSuffix(tarballName, ".tgz")
tarballName = strings.TrimSuffix(tarballName, ".tar.gz")
// Remove package name prefix to get version
prefix := strings.ReplaceAll(packageName, "/", "-") + "-"
version := strings.TrimPrefix(tarballName, prefix)
return packageName, version
}
// Fallback: parse path without /-/
// Format: /package/package-version.tgz or /@scope/package/package-version.tgz
path = strings.TrimPrefix(path, "/")
pathParts := strings.Split(path, "/")
if len(pathParts) < 2 {
return "", ""
}
var packageName, tarballName string
// Handle scoped packages
if strings.HasPrefix(pathParts[0], "@") && len(pathParts) >= 3 {
packageName = pathParts[0] + "/" + pathParts[1]
tarballName = pathParts[len(pathParts)-1]
} else {
packageName = pathParts[0]
tarballName = pathParts[len(pathParts)-1]
}
tarballName = strings.TrimSuffix(tarballName, ".tgz")
tarballName = strings.TrimSuffix(tarballName, ".tar.gz")
// Remove package name prefix to get version
prefix := strings.ReplaceAll(packageName, "/", "-") + "-"
version := strings.TrimPrefix(tarballName, prefix)
return packageName, version
}
// getProxyBaseURL constructs the proxy base URL from the request
func getProxyBaseURL(r *http.Request) string {
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
host := r.Host
return fmt.Sprintf("%s://%s/npm", scheme, host)
}
// rewriteMetadataURLs recursively rewrites upstream URLs to proxy URLs in metadata
func rewriteMetadataURLs(data interface{}, upstream, proxyBaseURL string) {
switch v := data.(type) {
case map[string]interface{}:
for key, value := range v {
if key == "tarball" || key == "dist" {
// Rewrite tarball URL
if strVal, ok := value.(string); ok {
v[key] = strings.Replace(strVal, upstream, proxyBaseURL, 1)
} else if distMap, ok := value.(map[string]interface{}); ok {
// Handle dist object with tarball field
rewriteMetadataURLs(distMap, upstream, proxyBaseURL)
}
} else {
// Recursively process nested objects
rewriteMetadataURLs(value, upstream, proxyBaseURL)
}
}
case []interface{}:
for _, item := range v {
rewriteMetadataURLs(item, upstream, proxyBaseURL)
}
}
}
+307
View File
@@ -0,0 +1,307 @@
package pypi
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"regexp"
"strings"
"github.com/lukaszraczylo/gohoarder/pkg/cache"
"github.com/lukaszraczylo/gohoarder/pkg/network"
"github.com/rs/zerolog/log"
)
// Handler implements the PyPI Simple API (PEP 503)
type Handler struct {
cache *cache.Manager
client *network.Client
upstream string
}
// Config holds PyPI proxy configuration
type Config struct {
Upstream string // Upstream PyPI index (e.g., pypi.org/simple)
}
// New creates a new PyPI proxy handler
func New(cacheManager *cache.Manager, client *network.Client, config Config) *Handler {
if config.Upstream == "" {
config.Upstream = "https://pypi.org/simple"
}
return &Handler{
cache: cacheManager,
client: client,
upstream: config.Upstream,
}
}
// ServeHTTP handles PyPI Simple API requests
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
path := strings.TrimPrefix(r.URL.Path, "/pypi")
log.Debug().Str("path", path).Str("method", r.Method).Msg("PyPI proxy request")
// PEP 503 Simple API endpoints:
// / - index page
// /{package}/ - package page with links to files
if path == "/" || path == "" {
// Index page
h.handleIndex(ctx, w, r)
} else if isPackagePage(path) {
// Package page
h.handlePackagePage(ctx, w, r, path)
} else if isPackageFile(path) {
// Package file download (wheel or sdist)
h.handlePackageFile(ctx, w, r, path)
} else {
http.Error(w, "Invalid PyPI request", http.StatusBadRequest)
}
}
// handleIndex handles the index page request
func (h *Handler) handleIndex(ctx context.Context, w http.ResponseWriter, r *http.Request) {
url := h.upstream + "/"
entry, err := h.cache.Get(ctx, "pypi", "index", "latest", func(ctx context.Context) (io.ReadCloser, string, error) {
body, statusCode, err := h.client.Get(ctx, url, nil)
if err != nil {
return nil, "", err
}
if statusCode != http.StatusOK {
body.Close()
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
}
return body, url, nil
})
if err != nil {
log.Error().Err(err).Str("url", url).Msg("Failed to fetch PyPI index")
http.Error(w, "Failed to fetch PyPI index", http.StatusBadGateway)
return
}
defer entry.Data.Close()
w.Header().Set("Content-Type", "text/html; charset=UTF-8")
io.Copy(w, entry.Data)
}
// handlePackagePage handles package page requests
func (h *Handler) handlePackagePage(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
url := h.upstream + path
packageName := extractPackageName(path)
entry, err := h.cache.Get(ctx, "pypi", packageName, "page", func(ctx context.Context) (io.ReadCloser, string, error) {
body, statusCode, err := h.client.Get(ctx, url, nil)
if err != nil {
return nil, "", err
}
if statusCode != http.StatusOK {
body.Close()
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
}
return body, url, nil
})
if err != nil {
log.Error().Err(err).Str("url", url).Msg("Failed to fetch package page")
http.Error(w, "Failed to fetch package page", http.StatusBadGateway)
return
}
defer entry.Data.Close()
// Read page into memory for URL rewriting
var buf bytes.Buffer
if _, err := io.Copy(&buf, entry.Data); err != nil {
log.Error().Err(err).Msg("Failed to read package page")
http.Error(w, "Failed to read package page", http.StatusInternalServerError)
return
}
// Rewrite package file URLs to point to our proxy
proxyBaseURL := getProxyBaseURL(r)
modifiedHTML := rewritePackagePageURLs(buf.String(), packageName, proxyBaseURL)
w.Header().Set("Content-Type", "text/html; charset=UTF-8")
w.Write([]byte(modifiedHTML))
}
// handlePackageFile handles package file download requests
func (h *Handler) handlePackageFile(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
packageName, version := extractPackageFileInfo(path)
// Check if we have the original URL from the rewritten package page
originalURL := r.URL.Query().Get("original_url")
// If no original URL provided, fall back to constructing from upstream
// (this handles direct file requests not from rewritten package pages)
if originalURL == "" {
originalURL = h.upstream + path
} else {
// Make the URL absolute if it's relative
if !strings.HasPrefix(originalURL, "http://") && !strings.HasPrefix(originalURL, "https://") {
originalURL = "https://pypi.org" + originalURL
}
}
entry, err := h.cache.Get(ctx, "pypi", packageName, version, func(ctx context.Context) (io.ReadCloser, string, error) {
body, statusCode, err := h.client.Get(ctx, originalURL, nil)
if err != nil {
return nil, "", err
}
if statusCode != http.StatusOK {
body.Close()
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
}
return body, originalURL, nil
})
if err != nil {
log.Error().Err(err).Str("url", originalURL).Msg("Failed to fetch package file")
http.Error(w, "Failed to fetch package file", http.StatusBadGateway)
return
}
defer entry.Data.Close()
// Determine content type based on file extension
contentType := "application/octet-stream"
if strings.HasSuffix(path, ".whl") {
contentType = "application/zip"
} else if strings.HasSuffix(path, ".tar.gz") {
contentType = "application/x-gzip"
} else if strings.HasSuffix(path, ".metadata") {
contentType = "text/plain; charset=UTF-8"
}
w.Header().Set("Content-Type", contentType)
io.Copy(w, entry.Data)
}
// isPackagePage checks if the request is for a package page
func isPackagePage(path string) bool {
// Package pages end with /
return strings.HasSuffix(path, "/")
}
// isPackageFile checks if the request is for a package file
func isPackageFile(path string) bool {
// Package files (not including .metadata files which need special handling)
return strings.HasSuffix(path, ".whl") ||
strings.HasSuffix(path, ".tar.gz") ||
strings.HasSuffix(path, ".zip") ||
strings.HasSuffix(path, ".egg")
}
// extractPackageName extracts package name from path
func extractPackageName(path string) string {
// Remove leading and trailing slashes
path = strings.Trim(path, "/")
// Remove /simple/ prefix if present
path = strings.TrimPrefix(path, "simple/")
// For package pages: /package-name/
// For files: /package-name/package-name-version.whl
parts := strings.Split(path, "/")
if len(parts) > 0 {
return parts[0]
}
return path
}
// extractPackageFileInfo extracts package name and version from file path
func extractPackageFileInfo(path string) (string, string) {
// Format: /package-name/package-name-version.whl
// or: /package-name/package-name-version.tar.gz
packageName := extractPackageName(path)
// Extract filename
parts := strings.Split(path, "/")
if len(parts) < 2 {
return packageName, ""
}
filename := parts[len(parts)-1]
// Remove extension
filename = strings.TrimSuffix(filename, ".whl")
filename = strings.TrimSuffix(filename, ".tar.gz")
filename = strings.TrimSuffix(filename, ".zip")
filename = strings.TrimSuffix(filename, ".egg")
// Extract version
// Filename format: package-name-version or package_name-version
// Version typically starts after last dash before build tags
versionParts := strings.Split(filename, "-")
if len(versionParts) >= 2 {
// Simple heuristic: version is the part that starts with a digit
for i := 1; i < len(versionParts); i++ {
if len(versionParts[i]) > 0 && versionParts[i][0] >= '0' && versionParts[i][0] <= '9' {
return packageName, versionParts[i]
}
}
}
return packageName, filename
}
// getProxyBaseURL constructs the proxy base URL from the request
func getProxyBaseURL(r *http.Request) string {
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
host := r.Host
return fmt.Sprintf("%s://%s/pypi", scheme, host)
}
// rewritePackagePageURLs rewrites package file URLs in HTML to point to proxy
func rewritePackagePageURLs(html, packageName, proxyBaseURL string) string {
// PyPI Simple API uses href attributes in anchor tags
// We need to rewrite URLs pointing to files.pythonhosted.org or pypi.org
// We preserve the original URL as a query parameter so we can fetch from the correct CDN
// Regex pattern to match href URLs pointing to package files
// Matches: href="https://files.pythonhosted.org/packages/.../filename.whl"
// Also matches: href="../../packages/.../filename.whl"
pattern := regexp.MustCompile(`href="([^"]*?(\.whl|\.tar\.gz|\.zip|\.egg)[^"]*?)"`)
result := pattern.ReplaceAllStringFunc(html, func(match string) string {
// Extract the full URL and filename
urlPattern := regexp.MustCompile(`href="([^"]+)"`)
urlMatch := urlPattern.FindStringSubmatch(match)
if len(urlMatch) < 2 {
return match
}
originalURL := urlMatch[1]
// Extract just the filename
filenamePattern := regexp.MustCompile(`([^/]+\.(whl|tar\.gz|zip|egg))`)
filenameMatch := filenamePattern.FindString(originalURL)
if filenameMatch != "" {
// Rewrite to proxy URL format: /pypi/package-name/filename?original_url=...
// This preserves the original CDN URL so we can fetch from the correct location
baseURL := strings.TrimSuffix(proxyBaseURL, "/simple")
// URL encode the original URL
encodedURL := strings.ReplaceAll(originalURL, "&", "%26")
encodedURL = strings.ReplaceAll(encodedURL, "=", "%3D")
newURL := fmt.Sprintf(`href="%s/%s/%s?original_url=%s"`, baseURL, packageName, filenameMatch, encodedURL)
return newURL
}
return match
})
return result
}
+319
View File
@@ -0,0 +1,319 @@
package osv
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/lukaszraczylo/gohoarder/pkg/config"
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
"github.com/lukaszraczylo/gohoarder/pkg/uuid"
"github.com/rs/zerolog/log"
)
const (
// ScannerName is the name of this scanner
ScannerName = "osv"
defaultOSVAPIURL = "https://api.osv.dev/v1/query"
)
// Scanner implements the Scanner interface using OSV.dev API
type Scanner struct {
config config.OSVConfig
httpClient *http.Client
}
// OSVRequest represents the request structure for OSV API
type OSVRequest struct {
Package PackageInfo `json:"package"`
Version string `json:"version,omitempty"`
}
// PackageInfo contains package ecosystem and name
type PackageInfo struct {
Name string `json:"name"`
Ecosystem string `json:"ecosystem"` // npm, PyPI, Go, etc.
}
// OSVResponse represents the response from OSV API
type OSVResponse struct {
Vulns []OSVVulnerability `json:"vulns"`
}
// OSVVulnerability represents a vulnerability in OSV format
type OSVVulnerability struct {
ID string `json:"id"`
Summary string `json:"summary"`
Details string `json:"details"`
Severity []OSVSeverity `json:"severity,omitempty"`
References []OSVReference `json:"references,omitempty"`
Affected []OSVAffected `json:"affected"`
DatabaseSpecific map[string]interface{} `json:"database_specific,omitempty"`
}
// OSVSeverity represents severity information
type OSVSeverity struct {
Type string `json:"type"` // CVSS_V3, etc.
Score string `json:"score"` // Severity score
}
// OSVReference represents a reference link
type OSVReference struct {
Type string `json:"type"` // WEB, ADVISORY, etc.
URL string `json:"url"`
}
// OSVAffected represents affected package versions
type OSVAffected struct {
Package PackageInfo `json:"package"`
Ranges []OSVRange `json:"ranges,omitempty"`
Versions []string `json:"versions,omitempty"`
DatabaseSpecific map[string]interface{} `json:"database_specific,omitempty"`
EcosystemSpecific map[string]interface{} `json:"ecosystem_specific,omitempty"`
}
// OSVRange represents version ranges
type OSVRange struct {
Type string `json:"type"` // SEMVER, GIT, etc.
Events []OSVEvent `json:"events"`
}
// OSVEvent represents version range events
type OSVEvent struct {
Introduced string `json:"introduced,omitempty"`
Fixed string `json:"fixed,omitempty"`
LastAffected string `json:"last_affected,omitempty"`
}
// New creates a new OSV scanner
func New(cfg config.OSVConfig) *Scanner {
apiURL := cfg.APIURL
if apiURL == "" {
apiURL = defaultOSVAPIURL
}
timeout := cfg.Timeout
if timeout == 0 {
timeout = 30 * time.Second
}
return &Scanner{
config: cfg,
httpClient: &http.Client{
Timeout: timeout,
},
}
}
// Name returns the scanner name
func (s *Scanner) Name() string {
return ScannerName
}
// Scan scans a package for vulnerabilities using OSV.dev API
func (s *Scanner) Scan(ctx context.Context, registry, packageName, version string, filePath string) (*metadata.ScanResult, error) {
// Convert registry to OSV ecosystem
ecosystem := s.registryToEcosystem(registry)
// Build request
req := OSVRequest{
Package: PackageInfo{
Name: packageName,
Ecosystem: ecosystem,
},
Version: version,
}
// Marshal request
reqBody, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal OSV request: %w", err)
}
// Create HTTP request
apiURL := s.config.APIURL
if apiURL == "" {
apiURL = defaultOSVAPIURL
}
httpReq, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(reqBody))
if err != nil {
return nil, fmt.Errorf("failed to create OSV request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
// Execute request
resp, err := s.httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("OSV API request failed: %w", err)
}
defer resp.Body.Close()
// Read response
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read OSV response: %w", err)
}
// Check status code
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("OSV API returned status %d: %s", resp.StatusCode, string(body))
}
// Parse response
var osvResp OSVResponse
if err := json.Unmarshal(body, &osvResp); err != nil {
return nil, fmt.Errorf("failed to parse OSV response: %w", err)
}
// Convert to metadata.ScanResult
return s.convertOSVResult(&osvResp, registry, packageName, version), nil
}
// registryToEcosystem converts our registry name to OSV ecosystem
func (s *Scanner) registryToEcosystem(registry string) string {
switch strings.ToLower(registry) {
case "npm":
return "npm"
case "pypi":
return "PyPI"
case "go":
return "Go"
case "maven":
return "Maven"
case "nuget":
return "NuGet"
case "cargo", "crates":
return "crates.io"
case "rubygems":
return "RubyGems"
default:
return registry
}
}
// convertOSVResult converts OSV response to metadata.ScanResult
func (s *Scanner) convertOSVResult(osvResp *OSVResponse, registry, packageName, version string) *metadata.ScanResult {
vulnerabilities := make([]metadata.Vulnerability, 0, len(osvResp.Vulns))
severityCounts := make(map[string]int)
for _, vuln := range osvResp.Vulns {
// Determine severity from various sources
severity := s.determineSeverity(&vuln)
severityCounts[severity]++
// Extract references
references := make([]string, 0, len(vuln.References))
for _, ref := range vuln.References {
references = append(references, ref.URL)
}
// Find fixed version
fixedVersion := s.findFixedVersion(&vuln, version)
vulnerabilities = append(vulnerabilities, metadata.Vulnerability{
ID: vuln.ID,
Severity: severity,
Title: vuln.Summary,
Description: vuln.Details,
References: references,
FixedIn: fixedVersion,
})
}
// Determine overall status
status := metadata.ScanStatusClean
if len(vulnerabilities) > 0 {
status = metadata.ScanStatusVulnerable
}
return &metadata.ScanResult{
ID: uuid.New().String(),
Registry: registry,
PackageName: packageName,
PackageVersion: version,
Scanner: s.Name(),
ScannedAt: time.Now(),
Status: status,
VulnerabilityCount: len(vulnerabilities),
Vulnerabilities: vulnerabilities,
Details: map[string]interface{}{
"ecosystem": s.registryToEcosystem(registry),
"severity_counts": severityCounts,
},
}
}
// determineSeverity extracts severity from OSV vulnerability
func (s *Scanner) determineSeverity(vuln *OSVVulnerability) string {
// Try to get severity from CVSS
for _, sev := range vuln.Severity {
if sev.Type == "CVSS_V3" || sev.Type == "CVSS_V2" {
// Parse CVSS score to severity
score := sev.Score
if strings.Contains(strings.ToUpper(score), "CRITICAL") {
return "CRITICAL"
} else if strings.Contains(strings.ToUpper(score), "HIGH") {
return "HIGH"
} else if strings.Contains(strings.ToUpper(score), "MEDIUM") {
return "MEDIUM"
} else if strings.Contains(strings.ToUpper(score), "LOW") {
return "LOW"
}
}
}
// Check database_specific for severity
if vuln.DatabaseSpecific != nil {
if sev, ok := vuln.DatabaseSpecific["severity"].(string); ok {
return strings.ToUpper(sev)
}
}
// Default to MEDIUM if unknown
return "MEDIUM"
}
// findFixedVersion extracts the fixed version from OSV affected ranges
func (s *Scanner) findFixedVersion(vuln *OSVVulnerability, currentVersion string) string {
for _, affected := range vuln.Affected {
for _, r := range affected.Ranges {
for _, event := range r.Events {
if event.Fixed != "" {
return event.Fixed
}
}
}
}
return ""
}
// Health checks if OSV API is reachable
func (s *Scanner) Health(ctx context.Context) error {
// Make a simple request to check API availability
apiURL := s.config.APIURL
if apiURL == "" {
apiURL = defaultOSVAPIURL
}
req, err := http.NewRequestWithContext(ctx, "GET", strings.Replace(apiURL, "/query", "", 1), nil)
if err != nil {
return fmt.Errorf("failed to create health check request: %w", err)
}
resp, err := s.httpClient.Do(req)
if err != nil {
return fmt.Errorf("OSV API not reachable: %w", err)
}
defer resp.Body.Close()
log.Debug().Int("status", resp.StatusCode).Msg("OSV health check passed")
return nil
}
+139
View File
@@ -0,0 +1,139 @@
package scanner
import (
"context"
"time"
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
"github.com/rs/zerolog/log"
)
// RescanWorker handles periodic re-scanning of cached packages
type RescanWorker struct {
manager *Manager
metadataStore metadata.MetadataStore
interval time.Duration
stopCh chan struct{}
}
// NewRescanWorker creates a new rescan worker
func NewRescanWorker(manager *Manager, metadataStore metadata.MetadataStore, interval time.Duration) *RescanWorker {
return &RescanWorker{
manager: manager,
metadataStore: metadataStore,
interval: interval,
stopCh: make(chan struct{}),
}
}
// Start begins the periodic re-scanning process
func (w *RescanWorker) Start(ctx context.Context) {
if !w.manager.enabled || w.interval == 0 {
log.Info().Msg("Rescan worker disabled")
return
}
log.Info().
Dur("interval", w.interval).
Msg("Starting package rescan worker")
ticker := time.NewTicker(w.interval)
defer ticker.Stop()
// Run initial scan immediately
w.rescanPackages(ctx)
for {
select {
case <-ticker.C:
w.rescanPackages(ctx)
case <-w.stopCh:
log.Info().Msg("Rescan worker stopped")
return
case <-ctx.Done():
log.Info().Msg("Rescan worker stopped (context cancelled)")
return
}
}
}
// Stop stops the rescan worker
func (w *RescanWorker) Stop() {
close(w.stopCh)
}
// rescanPackages re-scans packages that need updating
func (w *RescanWorker) rescanPackages(ctx context.Context) {
log.Info().Msg("Starting package rescan cycle")
// Get all packages
packages, err := w.metadataStore.ListPackages(ctx, &metadata.ListOptions{})
if err != nil {
log.Error().Err(err).Msg("Failed to list packages for rescan")
return
}
scanned := 0
skipped := 0
failed := 0
for _, pkg := range packages {
// Check if package needs rescanning
needsRescan, err := w.needsRescan(ctx, pkg)
if err != nil {
log.Error().
Err(err).
Str("package", pkg.Name).
Str("version", pkg.Version).
Msg("Failed to check rescan status")
failed++
continue
}
if !needsRescan {
skipped++
continue
}
// Rescan the package
// Note: We need the file path - we'll need to reconstruct it or get it from storage
// For now, we'll just log and skip actual rescanning
log.Info().
Str("registry", pkg.Registry).
Str("package", pkg.Name).
Str("version", pkg.Version).
Msg("Package needs rescanning")
// TODO: Implement actual rescanning by:
// 1. Retrieving package file from storage
// 2. Scanning it
// This would require access to storage backend
scanned++
}
log.Info().
Int("total", len(packages)).
Int("scanned", scanned).
Int("skipped", skipped).
Int("failed", failed).
Msg("Rescan cycle completed")
}
// needsRescan checks if a package needs to be rescanned
func (w *RescanWorker) needsRescan(ctx context.Context, pkg *metadata.Package) (bool, error) {
// Get latest scan result
scanResult, err := w.metadataStore.GetScanResult(ctx, pkg.Registry, pkg.Name, pkg.Version)
if err != nil {
// No scan result - needs scanning
return true, nil
}
// Check if scan is older than rescan interval
timeSinceLastScan := time.Since(scanResult.ScannedAt)
if timeSinceLastScan >= w.interval {
return true, nil
}
return false, nil
}
+432
View File
@@ -0,0 +1,432 @@
package scanner
import (
"context"
"fmt"
"strings"
"github.com/lukaszraczylo/gohoarder/pkg/config"
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
"github.com/lukaszraczylo/gohoarder/pkg/scanner/osv"
"github.com/lukaszraczylo/gohoarder/pkg/scanner/trivy"
"github.com/rs/zerolog/log"
)
// Scanner defines the interface for security scanners
type Scanner interface {
// Name returns the scanner name
Name() string
// Scan scans a package for vulnerabilities
Scan(ctx context.Context, registry, packageName, version string, filePath string) (*metadata.ScanResult, error)
// Health checks scanner health
Health(ctx context.Context) error
}
// DatabaseUpdater is implemented by scanners that need database updates
type DatabaseUpdater interface {
UpdateDatabase(ctx context.Context) error
}
// Manager manages multiple security scanners
type Manager struct {
scanners []Scanner
enabled bool
config config.SecurityConfig
metadataStore metadata.MetadataStore
}
// New creates a new scanner manager with configured scanners
func New(cfg config.SecurityConfig, metadataStore metadata.MetadataStore) (*Manager, error) {
manager := &Manager{
scanners: make([]Scanner, 0),
enabled: cfg.Enabled,
config: cfg,
metadataStore: metadataStore,
}
if !cfg.Enabled {
log.Info().Msg("Security scanning disabled")
return manager, nil
}
// Initialize Trivy scanner
if cfg.Scanners.Trivy.Enabled {
trivyScanner := trivy.New(cfg.Scanners.Trivy)
manager.RegisterScanner(trivyScanner)
log.Info().Msg("Trivy scanner enabled")
// Update database on startup if configured
if cfg.UpdateDBOnStartup {
if err := trivyScanner.UpdateDatabase(context.Background()); err != nil {
log.Warn().Err(err).Msg("Failed to update Trivy database on startup")
}
}
}
// Initialize OSV scanner
if cfg.Scanners.OSV.Enabled {
osvScanner := osv.New(cfg.Scanners.OSV)
manager.RegisterScanner(osvScanner)
log.Info().Msg("OSV scanner enabled")
}
if len(manager.scanners) == 0 {
log.Warn().Msg("Security scanning enabled but no scanners configured")
}
return manager, nil
}
// RegisterScanner registers a scanner
func (m *Manager) RegisterScanner(scanner Scanner) {
m.scanners = append(m.scanners, scanner)
}
// ScanPackage scans a package using all registered scanners and saves results
func (m *Manager) ScanPackage(ctx context.Context, registry, packageName, version string, filePath string) error {
if !m.enabled {
return nil
}
log.Info().
Str("registry", registry).
Str("package", packageName).
Str("version", version).
Msg("Starting security scan")
// Collect results from all scanners
var scanResults []*metadata.ScanResult
scannerNames := make([]string, 0)
for _, scanner := range m.scanners {
result, err := scanner.Scan(ctx, registry, packageName, version, filePath)
if err != nil {
log.Error().
Err(err).
Str("scanner", scanner.Name()).
Str("package", packageName).
Msg("Scanner failed")
continue
}
scanResults = append(scanResults, result)
scannerNames = append(scannerNames, scanner.Name())
log.Info().
Str("scanner", scanner.Name()).
Str("package", packageName).
Str("status", string(result.Status)).
Int("vulnerabilities", result.VulnerabilityCount).
Msg("Scan completed")
}
// If no scanners succeeded, return
if len(scanResults) == 0 {
log.Warn().
Str("package", packageName).
Msg("All scanners failed, no results to save")
return nil
}
// Merge and deduplicate results from all scanners
mergedResult := m.mergeResults(scanResults, scannerNames)
// Save consolidated result to metadata store
if err := m.metadataStore.SaveScanResult(ctx, mergedResult); err != nil {
log.Error().
Err(err).
Str("package", packageName).
Msg("Failed to save consolidated scan result")
return err
}
log.Info().
Str("package", packageName).
Str("status", string(mergedResult.Status)).
Int("total_vulnerabilities", mergedResult.VulnerabilityCount).
Int("unique_cves", len(mergedResult.Vulnerabilities)).
Strs("scanners", scannerNames).
Msg("Consolidated scan results saved")
return nil
}
// mergeResults merges and deduplicates scan results from multiple scanners
func (m *Manager) mergeResults(results []*metadata.ScanResult, scannerNames []string) *metadata.ScanResult {
if len(results) == 0 {
return nil
}
// Use first result as base
merged := &metadata.ScanResult{
ID: results[0].ID,
Registry: results[0].Registry,
PackageName: results[0].PackageName,
PackageVersion: results[0].PackageVersion,
Scanner: strings.Join(scannerNames, "+"), // Combined scanner name
ScannedAt: results[0].ScannedAt,
Status: metadata.ScanStatusClean,
Vulnerabilities: make([]metadata.Vulnerability, 0),
Details: make(map[string]interface{}),
}
// Use map for deduplication - key is CVE ID in uppercase
vulnMap := make(map[string]*metadata.Vulnerability)
severityCounts := make(map[string]int)
// Merge vulnerabilities from all scanners
for i, result := range results {
scannerName := scannerNames[i]
// Track scanner details
merged.Details[scannerName] = result.Details
// Add/merge vulnerabilities
for _, vuln := range result.Vulnerabilities {
cveKey := strings.ToUpper(vuln.ID)
// Check if CVE already exists
if existing, exists := vulnMap[cveKey]; exists {
// CVE found by multiple scanners - merge information
log.Debug().
Str("cve", vuln.ID).
Strs("existing_scanners", existing.DetectedBy).
Str("new_scanner", scannerName).
Msg("CVE found by multiple scanners, merging")
// Add scanner to DetectedBy list
existing.DetectedBy = append(existing.DetectedBy, scannerName)
// Prefer higher severity if different
if m.compareSeverity(vuln.Severity, existing.Severity) > 0 {
existing.Severity = vuln.Severity
}
// Merge references (deduplicate URLs)
refSet := make(map[string]bool)
for _, ref := range existing.References {
refSet[ref] = true
}
for _, ref := range vuln.References {
if !refSet[ref] {
existing.References = append(existing.References, ref)
refSet[ref] = true
}
}
// Prefer fixed_in version if not already set
if existing.FixedIn == "" && vuln.FixedIn != "" {
existing.FixedIn = vuln.FixedIn
}
} else {
// New CVE - add to map
vulnCopy := vuln
vulnCopy.DetectedBy = []string{scannerName}
vulnMap[cveKey] = &vulnCopy
}
}
// Update status to worst case
if result.Status == metadata.ScanStatusVulnerable {
merged.Status = metadata.ScanStatusVulnerable
} else if result.Status == metadata.ScanStatusPending && merged.Status != metadata.ScanStatusVulnerable {
merged.Status = metadata.ScanStatusPending
}
}
// Convert map to slice and count severities
for _, vuln := range vulnMap {
merged.Vulnerabilities = append(merged.Vulnerabilities, *vuln)
severityCounts[strings.ToUpper(vuln.Severity)]++
}
// Update counts
merged.VulnerabilityCount = len(merged.Vulnerabilities)
merged.Details["severity_counts"] = severityCounts
merged.Details["deduplication_summary"] = fmt.Sprintf(
"Merged results from %d scanners (%s)",
len(scannerNames),
strings.Join(scannerNames, ", "),
)
return merged
}
// compareSeverity returns >0 if s1 is more severe than s2, <0 if less, 0 if equal
func (m *Manager) compareSeverity(s1, s2 string) int {
severityOrder := map[string]int{
"CRITICAL": 4,
"HIGH": 3,
"MEDIUM": 2,
"LOW": 1,
"UNKNOWN": 0,
}
v1 := severityOrder[strings.ToUpper(s1)]
v2 := severityOrder[strings.ToUpper(s2)]
return v1 - v2
}
// CheckVulnerabilities checks if a package exceeds vulnerability thresholds
func (m *Manager) CheckVulnerabilities(ctx context.Context, registry, packageName, version string) (bool, string, error) {
if !m.enabled {
return false, "", nil
}
// Get active CVE bypasses from database
bypasses, err := m.metadataStore.GetActiveCVEBypasses(ctx)
if err != nil {
log.Warn().Err(err).Msg("Failed to get CVE bypasses, continuing without bypasses")
bypasses = []*metadata.CVEBypass{} // Continue without bypasses
}
// Check if entire package is bypassed
packageKey := fmt.Sprintf("%s/%s@%s", registry, packageName, version)
packageKeyNoVersion := fmt.Sprintf("%s/%s", registry, packageName)
for _, bypass := range bypasses {
if bypass.Type == metadata.BypassTypePackage && bypass.Active {
if bypass.Target == packageKey || bypass.Target == packageKeyNoVersion {
log.Info().
Str("package", packageKey).
Str("bypass_id", bypass.ID).
Str("reason", bypass.Reason).
Time("expires_at", bypass.ExpiresAt).
Msg("Package bypassed by admin")
return false, "", nil
}
}
}
// Get latest scan result
result, err := m.metadataStore.GetScanResult(ctx, registry, packageName, version)
if err != nil {
// No scan result found - allow download (will be scanned after)
return false, "", nil
}
// Build set of bypassed CVEs for fast lookup
bypassedCVEs := make(map[string]*metadata.CVEBypass)
for _, bypass := range bypasses {
if bypass.Type == metadata.BypassTypeCVE && bypass.Active {
// Check if bypass applies to this package (if AppliesTo is set)
if bypass.AppliesTo != "" && bypass.AppliesTo != packageKey && bypass.AppliesTo != packageKeyNoVersion {
continue // This bypass doesn't apply to this package
}
bypassedCVEs[strings.ToUpper(bypass.Target)] = bypass
}
}
// Count vulnerabilities by severity, excluding bypassed CVEs
severityCounts := make(map[string]int)
for _, vuln := range result.Vulnerabilities {
// Check if this CVE is bypassed
if bypass, ok := bypassedCVEs[strings.ToUpper(vuln.ID)]; ok {
log.Debug().
Str("cve", vuln.ID).
Str("package", packageName).
Str("bypass_id", bypass.ID).
Str("reason", bypass.Reason).
Time("expires_at", bypass.ExpiresAt).
Msg("CVE bypassed by admin")
continue
}
severityCounts[strings.ToUpper(vuln.Severity)]++
}
// Check against thresholds
thresholds := m.config.BlockThresholds
// Check critical
if thresholds.Critical >= 0 && severityCounts["CRITICAL"] > thresholds.Critical {
return true, fmt.Sprintf("Package has %d CRITICAL vulnerabilities (threshold: %d)",
severityCounts["CRITICAL"], thresholds.Critical), nil
}
// Check high
if thresholds.High >= 0 && severityCounts["HIGH"] > thresholds.High {
return true, fmt.Sprintf("Package has %d HIGH vulnerabilities (threshold: %d)",
severityCounts["HIGH"], thresholds.High), nil
}
// Check medium
if thresholds.Medium >= 0 && severityCounts["MEDIUM"] > thresholds.Medium {
return true, fmt.Sprintf("Package has %d MEDIUM vulnerabilities (threshold: %d)",
severityCounts["MEDIUM"], thresholds.Medium), nil
}
// Check low
if thresholds.Low >= 0 && severityCounts["LOW"] > thresholds.Low {
return true, fmt.Sprintf("Package has %d LOW vulnerabilities (threshold: %d)",
severityCounts["LOW"], thresholds.Low), nil
}
// Check block on severity
if m.config.BlockOnSeverity != "" && m.config.BlockOnSeverity != "none" {
severity := strings.ToUpper(m.config.BlockOnSeverity)
// Block if any vulnerabilities at or above the specified severity exist
switch severity {
case "CRITICAL":
if severityCounts["CRITICAL"] > 0 {
return true, fmt.Sprintf("Package has CRITICAL vulnerabilities"), nil
}
case "HIGH":
if severityCounts["CRITICAL"] > 0 || severityCounts["HIGH"] > 0 {
return true, fmt.Sprintf("Package has HIGH or CRITICAL vulnerabilities"), nil
}
case "MEDIUM":
if severityCounts["CRITICAL"] > 0 || severityCounts["HIGH"] > 0 || severityCounts["MEDIUM"] > 0 {
return true, fmt.Sprintf("Package has MEDIUM, HIGH, or CRITICAL vulnerabilities"), nil
}
case "LOW":
if len(result.Vulnerabilities) > 0 {
return true, fmt.Sprintf("Package has vulnerabilities"), nil
}
}
}
return false, "", nil
}
// UpdateDatabases updates vulnerability databases for all scanners
func (m *Manager) UpdateDatabases(ctx context.Context) error {
if !m.enabled {
return nil
}
log.Info().Msg("Updating vulnerability databases")
for _, scanner := range m.scanners {
if updater, ok := scanner.(DatabaseUpdater); ok {
if err := updater.UpdateDatabase(ctx); err != nil {
log.Error().
Err(err).
Str("scanner", scanner.Name()).
Msg("Failed to update database")
return err
}
}
}
log.Info().Msg("Vulnerability databases updated successfully")
return nil
}
// Health checks health of all scanners
func (m *Manager) Health(ctx context.Context) error {
if !m.enabled {
return nil
}
for _, scanner := range m.scanners {
if err := scanner.Health(ctx); err != nil {
return fmt.Errorf("scanner %s health check failed: %w", scanner.Name(), err)
}
}
return nil
}
+240
View File
@@ -0,0 +1,240 @@
package trivy
import (
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"strings"
"time"
"github.com/lukaszraczylo/gohoarder/pkg/config"
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
"github.com/lukaszraczylo/gohoarder/pkg/uuid"
"github.com/rs/zerolog/log"
)
// ScannerName is the name of this scanner
const ScannerName = "trivy"
// Scanner implements the Scanner interface using Trivy
type Scanner struct {
config config.TrivyConfig
}
// TrivyResult represents Trivy JSON output structure
type TrivyResult struct {
SchemaVersion int `json:"SchemaVersion"`
ArtifactName string `json:"ArtifactName"`
ArtifactType string `json:"ArtifactType"`
Metadata TrivyMetadata `json:"Metadata"`
Results []TrivyVulnResult `json:"Results"`
}
type TrivyMetadata struct {
OS *TrivyOS `json:"OS,omitempty"`
RepoTags []string `json:"RepoTags,omitempty"`
RepoDigests []string `json:"RepoDigests,omitempty"`
ImageConfig *TrivyImageConfig `json:"ImageConfig,omitempty"`
}
type TrivyOS struct {
Family string `json:"Family"`
Name string `json:"Name"`
}
type TrivyImageConfig struct {
Architecture string `json:"architecture"`
Created string `json:"created"`
}
type TrivyVulnResult struct {
Target string `json:"Target"`
Class string `json:"Class"`
Type string `json:"Type"`
Vulnerabilities []TrivyVulnerability `json:"Vulnerabilities"`
}
type TrivyVulnerability struct {
VulnerabilityID string `json:"VulnerabilityID"`
PkgName string `json:"PkgName"`
InstalledVersion string `json:"InstalledVersion"`
FixedVersion string `json:"FixedVersion"`
Severity string `json:"Severity"`
Title string `json:"Title"`
Description string `json:"Description"`
References []string `json:"References"`
PrimaryURL string `json:"PrimaryURL"`
}
// New creates a new Trivy scanner
func New(cfg config.TrivyConfig) *Scanner {
return &Scanner{
config: cfg,
}
}
// Name returns the scanner name
func (s *Scanner) Name() string {
return ScannerName
}
// UpdateDatabase updates Trivy's vulnerability database
func (s *Scanner) UpdateDatabase(ctx context.Context) error {
log.Info().Msg("Updating Trivy vulnerability database")
cmd := exec.CommandContext(ctx, "trivy", "image", "--download-db-only")
if s.config.CacheDB != "" {
cmd.Env = append(os.Environ(), fmt.Sprintf("TRIVY_CACHE_DIR=%s", s.config.CacheDB))
}
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("failed to update Trivy database: %w (output: %s)", err, string(output))
}
log.Info().Msg("Trivy vulnerability database updated successfully")
return nil
}
// Scan scans a package for vulnerabilities using Trivy
func (s *Scanner) Scan(ctx context.Context, registry, packageName, version string, filePath string) (*metadata.ScanResult, error) {
// Set timeout
if s.config.Timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, s.config.Timeout)
defer cancel()
}
// Determine scan type based on registry
scanType := s.determineScanType(registry, filePath)
// Build Trivy command
args := []string{
scanType,
"--format", "json",
"--quiet",
filePath,
}
cmd := exec.CommandContext(ctx, "trivy", args...)
// Set cache directory if configured
if s.config.CacheDB != "" {
cmd.Env = append(os.Environ(), fmt.Sprintf("TRIVY_CACHE_DIR=%s", s.config.CacheDB))
}
// Execute scan
output, err := cmd.Output()
if err != nil {
// Check if it's a timeout
if ctx.Err() == context.DeadlineExceeded {
return &metadata.ScanResult{
ID: uuid.New().String(),
Registry: registry,
PackageName: packageName,
PackageVersion: version,
Scanner: s.Name(),
ScannedAt: time.Now(),
Status: metadata.ScanStatusError,
Details: map[string]interface{}{
"error": "scan timeout",
},
}, nil
}
return nil, fmt.Errorf("trivy scan failed: %w", err)
}
// Parse Trivy output
var trivyResult TrivyResult
if err := json.Unmarshal(output, &trivyResult); err != nil {
return nil, fmt.Errorf("failed to parse Trivy output: %w", err)
}
// Convert to metadata.ScanResult
return s.convertTrivyResult(&trivyResult, registry, packageName, version), nil
}
// determineScanType determines the appropriate Trivy scan type
func (s *Scanner) determineScanType(registry, filePath string) string {
// For now, use filesystem scan for packages
// Container image scanning would need different handling
ext := strings.ToLower(filePath[strings.LastIndex(filePath, ".")+1:])
switch registry {
case "npm":
return "fs" // Filesystem scan for npm packages
case "pypi":
return "fs" // Filesystem scan for Python packages
case "go":
return "fs" // Filesystem scan for Go modules
default:
// Check file extension
if ext == "tar" || ext == "tgz" || ext == "gz" {
return "fs"
}
return "fs"
}
}
// convertTrivyResult converts Trivy result to metadata.ScanResult
func (s *Scanner) convertTrivyResult(trivyResult *TrivyResult, registry, packageName, version string) *metadata.ScanResult {
vulnerabilities := make([]metadata.Vulnerability, 0)
severityCounts := make(map[string]int)
// Aggregate all vulnerabilities from all results
for _, result := range trivyResult.Results {
for _, vuln := range result.Vulnerabilities {
// Count by severity
severityCounts[strings.ToUpper(vuln.Severity)]++
// Add to vulnerabilities list
vulnerabilities = append(vulnerabilities, metadata.Vulnerability{
ID: vuln.VulnerabilityID,
Severity: strings.ToUpper(vuln.Severity),
Title: vuln.Title,
Description: vuln.Description,
References: vuln.References,
FixedIn: vuln.FixedVersion,
})
}
}
// Determine overall status
status := metadata.ScanStatusClean
if len(vulnerabilities) > 0 {
status = metadata.ScanStatusVulnerable
}
return &metadata.ScanResult{
ID: uuid.New().String(),
Registry: registry,
PackageName: packageName,
PackageVersion: version,
Scanner: s.Name(),
ScannedAt: time.Now(),
Status: status,
VulnerabilityCount: len(vulnerabilities),
Vulnerabilities: vulnerabilities,
Details: map[string]interface{}{
"artifact_name": trivyResult.ArtifactName,
"artifact_type": trivyResult.ArtifactType,
"severity_counts": severityCounts,
},
}
}
// Health checks if Trivy is available and working
func (s *Scanner) Health(ctx context.Context) error {
// Check if trivy command exists
cmd := exec.CommandContext(ctx, "trivy", "--version")
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("trivy not available: %w (output: %s)", err, string(output))
}
log.Debug().Str("version", strings.TrimSpace(string(output))).Msg("Trivy health check passed")
return nil
}
+130
View File
@@ -0,0 +1,130 @@
package server
import (
"fmt"
"net/http"
"github.com/lukaszraczylo/gohoarder/pkg/config"
"github.com/lukaszraczylo/gohoarder/pkg/health"
"github.com/lukaszraczylo/gohoarder/pkg/logger"
"github.com/lukaszraczylo/gohoarder/pkg/metrics"
)
// Server wraps http.Server with configuration
type Server struct {
*http.Server
config *config.Config
healthChecker *health.Checker
}
// New creates a new HTTP server
func New(cfg *config.Config, healthChecker *health.Checker) (*Server, error) {
mux := http.NewServeMux()
// Register routes
registerRoutes(mux, cfg, healthChecker)
srv := &http.Server{
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
Handler: logger.Middleware(mux),
ReadTimeout: cfg.Server.ReadTimeout,
WriteTimeout: cfg.Server.WriteTimeout,
IdleTimeout: cfg.Server.IdleTimeout,
}
return &Server{
Server: srv,
config: cfg,
healthChecker: healthChecker,
}, nil
}
// registerRoutes registers all HTTP routes
func registerRoutes(mux *http.ServeMux, cfg *config.Config, healthChecker *health.Checker) {
// Health endpoints
mux.HandleFunc("/health", healthChecker.HealthHandler())
mux.HandleFunc("/health/ready", healthChecker.ReadyHandler())
// Metrics endpoint
mux.Handle("/metrics", metrics.Handler())
// API endpoints
mux.HandleFunc("/api/v1/info", handleInfo(cfg))
// Package manager proxy endpoints (placeholders for now)
if cfg.Handlers.Go.Enabled {
mux.HandleFunc("/go/", handleGoProxy())
}
if cfg.Handlers.NPM.Enabled {
mux.HandleFunc("/npm/", handleNPMProxy())
}
if cfg.Handlers.PyPI.Enabled {
mux.HandleFunc("/pypi/", handlePyPIProxy())
}
// Root endpoint
mux.HandleFunc("/", handleRoot())
}
// handleInfo returns server information
func handleInfo(cfg *config.Config) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
info := map[string]interface{}{
"name": "GoHoarder",
"version": "dev",
"handlers": map[string]bool{
"go": cfg.Handlers.Go.Enabled,
"npm": cfg.Handlers.NPM.Enabled,
"pypi": cfg.Handlers.PyPI.Enabled,
},
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"success":true,"data":%v}`, toJSON(info))
}
}
// handleGoProxy handles Go module proxy requests
func handleGoProxy() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// TODO: Implement Go proxy handler
http.Error(w, `{"success":false,"error":{"code":"NOT_IMPLEMENTED","message":"Go proxy not yet implemented"}}`, http.StatusNotImplemented)
}
}
// handleNPMProxy handles NPM registry requests
func handleNPMProxy() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// TODO: Implement NPM proxy handler
http.Error(w, `{"success":false,"error":{"code":"NOT_IMPLEMENTED","message":"NPM proxy not yet implemented"}}`, http.StatusNotImplemented)
}
}
// handlePyPIProxy handles PyPI requests
func handlePyPIProxy() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// TODO: Implement PyPI proxy handler
http.Error(w, `{"success":false,"error":{"code":"NOT_IMPLEMENTED","message":"PyPI proxy not yet implemented"}}`, http.StatusNotImplemented)
}
}
// handleRoot handles root path
func handleRoot() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
http.NotFound(w, r)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
fmt.Fprint(w, `{"success":true,"data":{"message":"GoHoarder - Universal Package Cache Proxy","docs":"https://github.com/lukaszraczylo/gohoarder"}}`)
}
}
// toJSON is a simple JSON encoder (replace with proper implementation)
func toJSON(v interface{}) string {
// Simplified for now - proper implementation would use goccy/go-json
return fmt.Sprintf("%v", v)
}
+415
View File
@@ -0,0 +1,415 @@
package filesystem
import (
"context"
"crypto/md5"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"sync"
"github.com/lukaszraczylo/gohoarder/pkg/errors"
"github.com/lukaszraczylo/gohoarder/pkg/metrics"
"github.com/lukaszraczylo/gohoarder/pkg/storage"
"github.com/rs/zerolog/log"
)
// FilesystemStorage implements storage.StorageBackend for local filesystem
type FilesystemStorage struct {
basePath string
quota int64
mu sync.RWMutex
used int64
}
// New creates a new filesystem storage backend
func New(basePath string, quota int64) (*FilesystemStorage, error) {
// Create base directory if it doesn't exist
if err := os.MkdirAll(basePath, 0755); err != nil {
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create base directory")
}
fs := &FilesystemStorage{
basePath: basePath,
quota: quota,
}
// Calculate initial usage
if err := fs.calculateUsage(); err != nil {
log.Warn().Err(err).Msg("Failed to calculate initial storage usage")
}
return fs, nil
}
// Get retrieves a file
func (fs *FilesystemStorage) Get(ctx context.Context, key string) (io.ReadCloser, error) {
// Check context
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
path := fs.keyToPath(key)
file, err := os.Open(path)
if err != nil {
if os.IsNotExist(err) {
metrics.RecordStorageOperation("filesystem", "get", "not_found")
return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key))
}
metrics.RecordStorageOperation("filesystem", "get", "error")
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to open file")
}
metrics.RecordStorageOperation("filesystem", "get", "success")
return file, nil
}
// Put stores a file atomically
func (fs *FilesystemStorage) Put(ctx context.Context, key string, data io.Reader, opts *storage.PutOptions) error {
// Check context
select {
case <-ctx.Done():
return ctx.Err()
default:
}
path := fs.keyToPath(key)
dir := filepath.Dir(path)
// Create directory
if err := os.MkdirAll(dir, 0755); err != nil {
metrics.RecordStorageOperation("filesystem", "put", "error")
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create directory")
}
// Create temp file for atomic write
tempPath := path + ".tmp"
tempFile, err := os.Create(tempPath)
if err != nil {
metrics.RecordStorageOperation("filesystem", "put", "error")
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create temp file")
}
// Calculate checksums while writing
// NOTE: MD5 is used for integrity verification (checksums), not cryptographic security
md5Hash := md5.New()
sha256Hash := sha256.New()
multiWriter := io.MultiWriter(tempFile, md5Hash, sha256Hash)
written, err := io.Copy(multiWriter, data)
if err != nil {
tempFile.Close()
os.Remove(tempPath)
metrics.RecordStorageOperation("filesystem", "put", "error")
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to write data")
}
if err := tempFile.Close(); err != nil {
os.Remove(tempPath)
metrics.RecordStorageOperation("filesystem", "put", "error")
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to close temp file")
}
// Check quota
fs.mu.Lock()
if fs.quota > 0 && fs.used+written > fs.quota {
fs.mu.Unlock()
os.Remove(tempPath)
metrics.RecordStorageOperation("filesystem", "put", "quota_exceeded")
return errors.QuotaExceeded(fs.quota)
}
fs.used += written
fs.mu.Unlock()
// Verify checksums if provided
if opts != nil {
md5Sum := hex.EncodeToString(md5Hash.Sum(nil))
sha256Sum := hex.EncodeToString(sha256Hash.Sum(nil))
if opts.ChecksumMD5 != "" && opts.ChecksumMD5 != md5Sum {
os.Remove(tempPath)
metrics.RecordStorageOperation("filesystem", "put", "checksum_error")
return errors.New(errors.ErrCodeChecksumMismatch, "MD5 checksum mismatch")
}
if opts.ChecksumSHA256 != "" && opts.ChecksumSHA256 != sha256Sum {
os.Remove(tempPath)
metrics.RecordStorageOperation("filesystem", "put", "checksum_error")
return errors.New(errors.ErrCodeChecksumMismatch, "SHA256 checksum mismatch")
}
}
// Atomic rename
if err := os.Rename(tempPath, path); err != nil {
os.Remove(tempPath)
fs.mu.Lock()
fs.used -= written
currentUsed := fs.used
fs.mu.Unlock()
metrics.RecordStorageOperation("filesystem", "put", "error")
metrics.UpdateCacheSize("filesystem", currentUsed)
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to rename temp file")
}
fs.mu.RLock()
currentUsed := fs.used
fs.mu.RUnlock()
metrics.RecordStorageOperation("filesystem", "put", "success")
metrics.UpdateCacheSize("filesystem", currentUsed)
return nil
}
// Delete removes a file
func (fs *FilesystemStorage) Delete(ctx context.Context, key string) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
path := fs.keyToPath(key)
// Get size before deletion
info, err := os.Stat(path)
if err != nil {
if os.IsNotExist(err) {
metrics.RecordStorageOperation("filesystem", "delete", "not_found")
return errors.NotFound(fmt.Sprintf("file not found: %s", key))
}
metrics.RecordStorageOperation("filesystem", "delete", "error")
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to stat file")
}
size := info.Size()
if err := os.Remove(path); err != nil {
metrics.RecordStorageOperation("filesystem", "delete", "error")
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to delete file")
}
fs.mu.Lock()
fs.used -= size
currentUsed := fs.used
fs.mu.Unlock()
metrics.RecordStorageOperation("filesystem", "delete", "success")
metrics.UpdateCacheSize("filesystem", currentUsed)
return nil
}
// Exists checks if a file exists
func (fs *FilesystemStorage) Exists(ctx context.Context, key string) (bool, error) {
select {
case <-ctx.Done():
return false, ctx.Err()
default:
}
path := fs.keyToPath(key)
_, err := os.Stat(path)
if err != nil {
if os.IsNotExist(err) {
return false, nil
}
return false, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to check existence")
}
return true, nil
}
// List lists files with prefix
func (fs *FilesystemStorage) List(ctx context.Context, prefix string, opts *storage.ListOptions) ([]storage.StorageObject, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
searchPath := fs.keyToPath(prefix)
var objects []storage.StorageObject
err := filepath.Walk(searchPath, func(path string, info os.FileInfo, err error) error {
if err != nil {
return nil // Skip errors
}
if info.IsDir() {
return nil
}
// Convert path back to key
relPath, _ := filepath.Rel(fs.basePath, path)
key := filepath.ToSlash(relPath)
objects = append(objects, storage.StorageObject{
Key: key,
Size: info.Size(),
Modified: info.ModTime(),
})
return nil
})
if err != nil && !os.IsNotExist(err) {
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to list files")
}
// Apply pagination if requested
if opts != nil {
start := opts.Offset
end := len(objects)
if opts.MaxResults > 0 && start+opts.MaxResults < end {
end = start + opts.MaxResults
}
if start < len(objects) {
objects = objects[start:end]
}
}
return objects, nil
}
// Stat gets file metadata
func (fs *FilesystemStorage) Stat(ctx context.Context, key string) (*storage.StorageInfo, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
path := fs.keyToPath(key)
info, err := os.Stat(path)
if err != nil {
if os.IsNotExist(err) {
return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key))
}
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to stat file")
}
return &storage.StorageInfo{
Key: key,
Size: info.Size(),
Modified: info.ModTime(),
}, nil
}
// GetQuota returns quota information
func (fs *FilesystemStorage) GetQuota(ctx context.Context) (*storage.QuotaInfo, error) {
fs.mu.RLock()
used := fs.used
fs.mu.RUnlock()
available := fs.quota - used
if available < 0 {
available = 0
}
return &storage.QuotaInfo{
Used: used,
Available: available,
Limit: fs.quota,
}, nil
}
// Health checks filesystem health
func (fs *FilesystemStorage) Health(ctx context.Context) error {
// Check if base path is accessible
if _, err := os.Stat(fs.basePath); err != nil {
return errors.Wrap(err, errors.ErrCodeStorageFailure, "base path not accessible")
}
// Try to create a temp file (sanitize path to prevent traversal)
tempPath := filepath.Clean(filepath.Join(fs.basePath, ".health_check"))
f, err := os.Create(tempPath)
if err != nil {
return errors.Wrap(err, errors.ErrCodeStorageFailure, "cannot write to storage")
}
f.Close()
os.Remove(tempPath)
return nil
}
// Close closes the storage backend
func (fs *FilesystemStorage) Close() error {
// Nothing to close for filesystem
return nil
}
// GetLocalPath returns the local filesystem path for a storage key
// This implements storage.LocalPathProvider interface
func (fs *FilesystemStorage) GetLocalPath(ctx context.Context, key string) (string, error) {
select {
case <-ctx.Done():
return "", ctx.Err()
default:
}
path := fs.keyToPath(key)
// Verify file exists
if _, err := os.Stat(path); err != nil {
if os.IsNotExist(err) {
return "", errors.NotFound(fmt.Sprintf("file not found: %s", key))
}
return "", errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to stat file")
}
return path, nil
}
// keyToPath converts a storage key to filesystem path
func (fs *FilesystemStorage) keyToPath(key string) string {
// Sanitize key to prevent path traversal
key = filepath.Clean(key)
// Remove any leading slashes or dots
key = strings.TrimPrefix(key, "/")
// Keep removing ../ until there are no more
for strings.HasPrefix(key, "../") || strings.HasPrefix(key, "..\\") {
key = strings.TrimPrefix(key, "../")
key = strings.TrimPrefix(key, "..\\")
}
// Final clean and ensure it's within base path
key = filepath.Clean(key)
if key == ".." || strings.HasPrefix(key, "../") || strings.HasPrefix(key, "..\\") {
key = ""
}
return filepath.Join(fs.basePath, key)
}
// calculateUsage calculates current storage usage
func (fs *FilesystemStorage) calculateUsage() error {
var total int64
err := filepath.Walk(fs.basePath, func(path string, info os.FileInfo, err error) error {
if err != nil {
return nil // Skip errors
}
if !info.IsDir() {
total += info.Size()
}
return nil
})
if err != nil {
return err
}
fs.mu.Lock()
fs.used = total
fs.mu.Unlock()
metrics.UpdateCacheSize("filesystem", total)
return nil
}
+757
View File
@@ -0,0 +1,757 @@
package filesystem
import (
"context"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"
"github.com/lukaszraczylo/gohoarder/pkg/storage"
"github.com/stretchr/testify/suite"
)
type FilesystemStorageTestSuite struct {
suite.Suite
tempDir string
fs *FilesystemStorage
}
func (s *FilesystemStorageTestSuite) SetupTest() {
var err error
s.tempDir, err = os.MkdirTemp("", "gohoarder-test-*")
s.Require().NoError(err)
s.fs, err = New(s.tempDir, 1024*1024) // 1MB quota
s.Require().NoError(err)
}
func (s *FilesystemStorageTestSuite) TearDownTest() {
if s.fs != nil {
s.fs.Close()
}
if s.tempDir != "" {
os.RemoveAll(s.tempDir)
}
}
func TestFilesystemStorageTestSuite(t *testing.T) {
suite.Run(t, new(FilesystemStorageTestSuite))
}
// Test Put operation
func (s *FilesystemStorageTestSuite) TestPut() {
tests := []struct {
name string
key string
data string
opts *storage.PutOptions
expectError bool
errorCheck func(error) bool
}{
{
name: "successful put",
key: "test/file.txt",
data: "hello world",
opts: nil,
expectError: false,
},
{
name: "put with valid MD5 checksum",
key: "test/checksummed.txt",
data: "test data",
opts: &storage.PutOptions{ChecksumMD5: "eb733a00c0c9d336e65691a37ab54293"},
expectError: false,
},
{
name: "put with invalid MD5 checksum",
key: "test/bad-checksum.txt",
data: "test data",
opts: &storage.PutOptions{ChecksumMD5: "invalid"},
expectError: true,
},
{
name: "put with nested path",
key: "deep/nested/path/file.txt",
data: "nested content",
opts: nil,
expectError: false,
},
{
name: "put with path traversal attempt",
key: "../../../etc/passwd",
data: "malicious",
opts: nil,
expectError: false, // Should be sanitized, not error
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
ctx := context.Background()
reader := strings.NewReader(tt.data)
err := s.fs.Put(ctx, tt.key, reader, tt.opts)
if tt.expectError {
s.Error(err)
} else {
s.NoError(err)
// Verify file exists
exists, err := s.fs.Exists(ctx, tt.key)
s.NoError(err)
s.True(exists)
}
})
}
}
// Test Get operation
func (s *FilesystemStorageTestSuite) TestGet() {
ctx := context.Background()
// Setup: Put a test file
testData := "test content for retrieval"
err := s.fs.Put(ctx, "test/get.txt", strings.NewReader(testData), nil)
s.Require().NoError(err)
tests := []struct {
name string
key string
expectError bool
expectData string
}{
{
name: "get existing file",
key: "test/get.txt",
expectError: false,
expectData: testData,
},
{
name: "get non-existent file",
key: "does/not/exist.txt",
expectError: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
reader, err := s.fs.Get(ctx, tt.key)
if tt.expectError {
s.Error(err)
s.Nil(reader)
} else {
s.NoError(err)
s.NotNil(reader)
defer reader.Close()
data, err := io.ReadAll(reader)
s.NoError(err)
s.Equal(tt.expectData, string(data))
}
})
}
}
// Test Delete operation
func (s *FilesystemStorageTestSuite) TestDelete() {
ctx := context.Background()
tests := []struct {
name string
setupKey string
deleteKey string
expectError bool
}{
{
name: "delete existing file",
setupKey: "test/delete-me.txt",
deleteKey: "test/delete-me.txt",
expectError: false,
},
{
name: "delete non-existent file",
setupKey: "",
deleteKey: "does/not/exist.txt",
expectError: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
// Setup
if tt.setupKey != "" {
err := s.fs.Put(ctx, tt.setupKey, strings.NewReader("to be deleted"), nil)
s.Require().NoError(err)
}
// Test delete
err := s.fs.Delete(ctx, tt.deleteKey)
if tt.expectError {
s.Error(err)
} else {
s.NoError(err)
// Verify file no longer exists
exists, err := s.fs.Exists(ctx, tt.deleteKey)
s.NoError(err)
s.False(exists)
}
})
}
}
// Test Exists operation
func (s *FilesystemStorageTestSuite) TestExists() {
ctx := context.Background()
// Setup: Put a test file
err := s.fs.Put(ctx, "test/exists.txt", strings.NewReader("content"), nil)
s.Require().NoError(err)
tests := []struct {
name string
key string
exists bool
}{
{
name: "existing file",
key: "test/exists.txt",
exists: true,
},
{
name: "non-existent file",
key: "test/does-not-exist.txt",
exists: false,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
exists, err := s.fs.Exists(ctx, tt.key)
s.NoError(err)
s.Equal(tt.exists, exists)
})
}
}
// Test List operation
func (s *FilesystemStorageTestSuite) TestList() {
ctx := context.Background()
// Setup: Create multiple files
files := []string{
"packages/npm/react/17.0.1/package.json",
"packages/npm/react/17.0.2/package.json",
"packages/npm/vue/3.0.0/package.json",
"packages/pypi/django/3.2.0/wheel.whl",
}
for _, file := range files {
err := s.fs.Put(ctx, file, strings.NewReader("content"), nil)
s.Require().NoError(err)
}
tests := []struct {
name string
prefix string
opts *storage.ListOptions
expectedCount int
expectedKeys []string
}{
{
name: "list all npm packages",
prefix: "packages/npm",
opts: nil,
expectedCount: 3,
},
{
name: "list react packages",
prefix: "packages/npm/react",
opts: nil,
expectedCount: 2,
},
{
name: "list with pagination",
prefix: "packages/npm",
opts: &storage.ListOptions{MaxResults: 2, Offset: 0},
expectedCount: 2,
},
{
name: "list with offset",
prefix: "packages/npm",
opts: &storage.ListOptions{MaxResults: 2, Offset: 1},
expectedCount: 2,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
objects, err := s.fs.List(ctx, tt.prefix, tt.opts)
s.NoError(err)
s.Equal(tt.expectedCount, len(objects))
// Verify objects have required fields
for _, obj := range objects {
s.NotEmpty(obj.Key)
s.Greater(obj.Size, int64(0))
s.False(obj.Modified.IsZero())
}
})
}
}
// Test Stat operation
func (s *FilesystemStorageTestSuite) TestStat() {
ctx := context.Background()
// Setup: Put a test file
testData := "stat test content"
testKey := "test/stat.txt"
err := s.fs.Put(ctx, testKey, strings.NewReader(testData), nil)
s.Require().NoError(err)
tests := []struct {
name string
key string
expectError bool
}{
{
name: "stat existing file",
key: testKey,
expectError: false,
},
{
name: "stat non-existent file",
key: "does/not/exist.txt",
expectError: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
info, err := s.fs.Stat(ctx, tt.key)
if tt.expectError {
s.Error(err)
s.Nil(info)
} else {
s.NoError(err)
s.NotNil(info)
s.Equal(tt.key, info.Key)
s.Equal(int64(len(testData)), info.Size)
s.False(info.Modified.IsZero())
}
})
}
}
// Test Quota enforcement
func (s *FilesystemStorageTestSuite) TestQuotaEnforcement() {
ctx := context.Background()
// Create a new filesystem with small quota (100 bytes)
smallQuotaDir, err := os.MkdirTemp("", "gohoarder-quota-*")
s.Require().NoError(err)
defer os.RemoveAll(smallQuotaDir)
smallFs, err := New(smallQuotaDir, 100)
s.Require().NoError(err)
defer smallFs.Close()
// First write should succeed
err = smallFs.Put(ctx, "file1.txt", strings.NewReader("small content"), nil)
s.NoError(err)
// Large write should fail due to quota
largeData := strings.Repeat("x", 200)
err = smallFs.Put(ctx, "large.txt", strings.NewReader(largeData), nil)
s.Error(err)
// Verify quota info
quotaInfo, err := smallFs.GetQuota(ctx)
s.NoError(err)
s.Equal(int64(100), quotaInfo.Limit)
s.Greater(quotaInfo.Used, int64(0))
s.LessOrEqual(quotaInfo.Used, quotaInfo.Limit)
}
// Test GetQuota operation
func (s *FilesystemStorageTestSuite) TestGetQuota() {
ctx := context.Background()
// Put some files
err := s.fs.Put(ctx, "file1.txt", strings.NewReader("content1"), nil)
s.Require().NoError(err)
err = s.fs.Put(ctx, "file2.txt", strings.NewReader("content2"), nil)
s.Require().NoError(err)
quotaInfo, err := s.fs.GetQuota(ctx)
s.NoError(err)
s.NotNil(quotaInfo)
s.Equal(int64(1024*1024), quotaInfo.Limit)
s.Greater(quotaInfo.Used, int64(0))
s.Greater(quotaInfo.Available, int64(0))
s.Equal(quotaInfo.Limit, quotaInfo.Used+quotaInfo.Available)
}
// Test Health check
func (s *FilesystemStorageTestSuite) TestHealth() {
ctx := context.Background()
// Healthy filesystem
err := s.fs.Health(ctx)
s.NoError(err)
// Unhealthy filesystem (removed directory)
badDir := filepath.Join(s.tempDir, "nonexistent")
badFs := &FilesystemStorage{basePath: badDir}
err = badFs.Health(ctx)
s.Error(err)
}
// Test Context cancellation
func (s *FilesystemStorageTestSuite) TestContextCancellation() {
// Create cancelled context
ctx, cancel := context.WithCancel(context.Background())
cancel()
tests := []struct {
name string
fn func() error
}{
{
name: "Get with cancelled context",
fn: func() error {
_, err := s.fs.Get(ctx, "test.txt")
return err
},
},
{
name: "Put with cancelled context",
fn: func() error {
return s.fs.Put(ctx, "test.txt", strings.NewReader("data"), nil)
},
},
{
name: "Delete with cancelled context",
fn: func() error {
return s.fs.Delete(ctx, "test.txt")
},
},
{
name: "Exists with cancelled context",
fn: func() error {
_, err := s.fs.Exists(ctx, "test.txt")
return err
},
},
{
name: "List with cancelled context",
fn: func() error {
_, err := s.fs.List(ctx, "test", nil)
return err
},
},
{
name: "Stat with cancelled context",
fn: func() error {
_, err := s.fs.Stat(ctx, "test.txt")
return err
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
err := tt.fn()
s.Error(err)
s.Equal(context.Canceled, err)
})
}
}
// Test concurrent access (race condition testing)
func (s *FilesystemStorageTestSuite) TestConcurrentAccess() {
ctx := context.Background()
numGoroutines := 10
numOperations := 100
var wg sync.WaitGroup
// Concurrent writes
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < numOperations; j++ {
key := fmt.Sprintf("concurrent/%d/%d.txt", id, j)
data := fmt.Sprintf("data-%d-%d", id, j)
err := s.fs.Put(ctx, key, strings.NewReader(data), nil)
s.NoError(err)
}
}(i)
}
wg.Wait()
// Verify all files exist
objects, err := s.fs.List(ctx, "concurrent", nil)
s.NoError(err)
s.Equal(numGoroutines*numOperations, len(objects))
}
// Test concurrent reads and writes
func (s *FilesystemStorageTestSuite) TestConcurrentReadsAndWrites() {
ctx := context.Background()
// Setup: Create some initial files
for i := 0; i < 10; i++ {
key := fmt.Sprintf("shared/file-%d.txt", i)
err := s.fs.Put(ctx, key, strings.NewReader(fmt.Sprintf("initial-%d", i)), nil)
s.Require().NoError(err)
}
var wg sync.WaitGroup
numReaders := 5
numWriters := 5
numOps := 50
// Concurrent readers
for i := 0; i < numReaders; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < numOps; j++ {
key := fmt.Sprintf("shared/file-%d.txt", j%10)
reader, err := s.fs.Get(ctx, key)
if err == nil {
io.ReadAll(reader)
reader.Close()
}
}
}(i)
}
// Concurrent writers
for i := 0; i < numWriters; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < numOps; j++ {
key := fmt.Sprintf("shared/writer-%d-%d.txt", id, j)
data := fmt.Sprintf("writer-%d-%d", id, j)
s.fs.Put(ctx, key, strings.NewReader(data), nil)
}
}(i)
}
wg.Wait()
// Verify quota tracking is consistent
quotaInfo, err := s.fs.GetQuota(ctx)
s.NoError(err)
s.Greater(quotaInfo.Used, int64(0))
}
// Test Delete updates quota correctly
func (s *FilesystemStorageTestSuite) TestDeleteUpdatesQuota() {
ctx := context.Background()
// Put a file
testData := "test data for quota tracking"
err := s.fs.Put(ctx, "quota/test.txt", strings.NewReader(testData), nil)
s.Require().NoError(err)
// Get quota before delete
quotaBefore, err := s.fs.GetQuota(ctx)
s.Require().NoError(err)
// Delete the file
err = s.fs.Delete(ctx, "quota/test.txt")
s.NoError(err)
// Get quota after delete
quotaAfter, err := s.fs.GetQuota(ctx)
s.NoError(err)
// Quota should have decreased
s.Less(quotaAfter.Used, quotaBefore.Used)
}
// Test atomic write behavior
func (s *FilesystemStorageTestSuite) TestAtomicWrite() {
ctx := context.Background()
key := "atomic/test.txt"
// Initial write
err := s.fs.Put(ctx, key, strings.NewReader("initial"), nil)
s.Require().NoError(err)
// Concurrent readers should never see partial writes
var wg sync.WaitGroup
stopReading := make(chan struct{})
readErrors := make(chan error, 100)
// Start readers
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopReading:
return
default:
reader, err := s.fs.Get(ctx, key)
if err != nil {
readErrors <- err
continue
}
data, err := io.ReadAll(reader)
reader.Close()
if err != nil {
readErrors <- err
continue
}
// Data should be either "initial" or "updated", never partial
content := string(data)
if content != "initial" && content != "updated" {
readErrors <- fmt.Errorf("read partial data: %s", content)
}
}
}
}()
}
// Perform update
time.Sleep(10 * time.Millisecond)
err = s.fs.Put(ctx, key, strings.NewReader("updated"), nil)
s.NoError(err)
// Stop readers
time.Sleep(10 * time.Millisecond)
close(stopReading)
wg.Wait()
close(readErrors)
// Check for read errors
for err := range readErrors {
s.NoError(err)
}
}
// Test path sanitization
func (s *FilesystemStorageTestSuite) TestPathSanitization() {
ctx := context.Background()
maliciousPaths := []string{
"../../../etc/passwd",
"/../secret.txt",
"./../../outside.txt",
"//etc/passwd",
}
for _, path := range maliciousPaths {
s.Run(fmt.Sprintf("sanitize_%s", path), func() {
err := s.fs.Put(ctx, path, strings.NewReader("malicious"), nil)
s.NoError(err) // Should succeed but sanitize path
// Verify file is inside base directory
sanitized := s.fs.keyToPath(path)
s.True(strings.HasPrefix(sanitized, s.tempDir),
"Sanitized path %s should be inside %s", sanitized, s.tempDir)
})
}
}
// Test checksum validation
func (s *FilesystemStorageTestSuite) TestChecksumValidation() {
ctx := context.Background()
testData := "checksum test data"
// Correct checksums calculated for "checksum test data"
correctMD5 := "7dd7323e8ce3e087972f93d3711ef62b"
tests := []struct {
name string
opts *storage.PutOptions
expectError bool
}{
{
name: "valid MD5",
opts: &storage.PutOptions{ChecksumMD5: correctMD5},
expectError: false,
},
{
name: "invalid MD5",
opts: &storage.PutOptions{ChecksumMD5: "invalid"},
expectError: true,
},
{
name: "empty checksum (no validation)",
opts: &storage.PutOptions{ChecksumMD5: ""},
expectError: false,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
key := fmt.Sprintf("checksum/%s.txt", tt.name)
err := s.fs.Put(ctx, key, strings.NewReader(testData), tt.opts)
if tt.expectError {
s.Error(err)
} else {
s.NoError(err)
}
})
}
}
// Benchmark Put operation
func BenchmarkFilesystemPut(b *testing.B) {
tempDir, _ := os.MkdirTemp("", "gohoarder-bench-*")
defer os.RemoveAll(tempDir)
fs, _ := New(tempDir, 1024*1024*1024) // 1GB quota
defer fs.Close()
ctx := context.Background()
data := strings.Repeat("x", 1024) // 1KB
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := fmt.Sprintf("bench/file-%d.txt", i)
fs.Put(ctx, key, strings.NewReader(data), nil)
}
}
// Benchmark Get operation
func BenchmarkFilesystemGet(b *testing.B) {
tempDir, _ := os.MkdirTemp("", "gohoarder-bench-*")
defer os.RemoveAll(tempDir)
fs, _ := New(tempDir, 1024*1024*1024)
defer fs.Close()
ctx := context.Background()
data := strings.Repeat("x", 1024)
// Setup: Create test file
fs.Put(ctx, "bench/test.txt", strings.NewReader(data), nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
reader, _ := fs.Get(ctx, "bench/test.txt")
if reader != nil {
io.ReadAll(reader)
reader.Close()
}
}
}
+91
View File
@@ -0,0 +1,91 @@
package storage
import (
"context"
"io"
"time"
)
// StorageBackend defines the interface for package storage
type StorageBackend interface {
// Get retrieves a package by key
Get(ctx context.Context, key string) (io.ReadCloser, error)
// Put stores a package
Put(ctx context.Context, key string, data io.Reader, opts *PutOptions) error
// Delete removes a package
Delete(ctx context.Context, key string) error
// Exists checks if a package exists
Exists(ctx context.Context, key string) (bool, error)
// List lists packages with prefix
List(ctx context.Context, prefix string, opts *ListOptions) ([]StorageObject, error)
// Stat gets package metadata
Stat(ctx context.Context, key string) (*StorageInfo, error)
// GetQuota returns quota information
GetQuota(ctx context.Context) (*QuotaInfo, error)
// Health checks backend health
Health(ctx context.Context) error
// Close closes the backend
Close() error
}
// PutOptions contains options for Put operations
type PutOptions struct {
ContentType string
Metadata map[string]string
ChecksumMD5 string
ChecksumSHA256 string
}
// ListOptions contains options for List operations
type ListOptions struct {
MaxResults int
Offset int
}
// StorageObject represents a stored object
type StorageObject struct {
Key string
Size int64
Modified time.Time
ETag string
}
// StorageInfo contains detailed object information
type StorageInfo struct {
Key string
Size int64
Modified time.Time
ETag string
ContentType string
Metadata map[string]string
Checksums *Checksums
}
// Checksums contains file checksums
type Checksums struct {
MD5 string
SHA256 string
}
// QuotaInfo contains quota information
type QuotaInfo struct {
Used int64
Available int64
Limit int64
}
// LocalPathProvider is an optional interface that storage backends can implement
// to provide direct file system paths for scanning without creating temp copies
type LocalPathProvider interface {
// GetLocalPath returns the local filesystem path for a storage key
// Returns empty string if the backend doesn't support local paths (e.g., S3, SMB)
GetLocalPath(ctx context.Context, key string) (string, error)
}
+443
View File
@@ -0,0 +1,443 @@
package s3
import (
"bytes"
"context"
"crypto/md5"
"crypto/sha256"
"encoding/hex"
stderrors "errors"
"fmt"
"io"
"strings"
"sync"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/lukaszraczylo/gohoarder/pkg/errors"
"github.com/lukaszraczylo/gohoarder/pkg/metrics"
"github.com/lukaszraczylo/gohoarder/pkg/storage"
"github.com/rs/zerolog/log"
)
// S3Storage implements storage.StorageBackend for AWS S3
type S3Storage struct {
client *s3.Client
bucket string
prefix string
quota int64
mu sync.RWMutex
used int64
}
// Config holds S3 configuration
type Config struct {
Bucket string
Region string
Endpoint string // For S3-compatible services (MinIO, etc.)
AccessKeyID string
SecretAccessKey string
Prefix string // Optional prefix for all keys
Quota int64 // Quota in bytes (0 = unlimited)
ForcePathStyle bool // For S3-compatible services
}
// New creates a new S3 storage backend
func New(ctx context.Context, cfg Config) (*S3Storage, error) {
if cfg.Bucket == "" {
return nil, errors.New(errors.ErrCodeInvalidConfig, "S3 bucket is required")
}
if cfg.Region == "" {
return nil, errors.New(errors.ErrCodeInvalidConfig, "S3 region is required")
}
// Build AWS config
var awsCfg aws.Config
var err error
if cfg.AccessKeyID != "" && cfg.SecretAccessKey != "" {
// Use static credentials
awsCfg, err = config.LoadDefaultConfig(ctx,
config.WithRegion(cfg.Region),
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
cfg.AccessKeyID,
cfg.SecretAccessKey,
"",
)),
)
} else {
// Use default credential chain
awsCfg, err = config.LoadDefaultConfig(ctx,
config.WithRegion(cfg.Region),
)
}
if err != nil {
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to load AWS config")
}
// Create S3 client
var s3Options []func(*s3.Options)
if cfg.Endpoint != "" {
s3Options = append(s3Options, func(o *s3.Options) {
o.BaseEndpoint = aws.String(cfg.Endpoint)
o.UsePathStyle = cfg.ForcePathStyle
})
}
client := s3.NewFromConfig(awsCfg, s3Options...)
s3Storage := &S3Storage{
client: client,
bucket: cfg.Bucket,
prefix: strings.TrimSuffix(cfg.Prefix, "/"),
quota: cfg.Quota,
}
// Calculate initial usage
if err := s3Storage.calculateUsage(ctx); err != nil {
log.Warn().Err(err).Msg("Failed to calculate initial S3 storage usage")
}
return s3Storage, nil
}
// Get retrieves a file from S3
func (s *S3Storage) Get(ctx context.Context, key string) (io.ReadCloser, error) {
s3Key := s.buildKey(key)
input := &s3.GetObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(s3Key),
}
result, err := s.client.GetObject(ctx, input)
if err != nil {
if isNotFoundError(err) {
metrics.RecordStorageOperation("s3", "get", "not_found")
return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key))
}
metrics.RecordStorageOperation("s3", "get", "error")
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get object from S3")
}
metrics.RecordStorageOperation("s3", "get", "success")
return result.Body, nil
}
// Put stores a file in S3
func (s *S3Storage) Put(ctx context.Context, key string, data io.Reader, opts *storage.PutOptions) error {
s3Key := s.buildKey(key)
// Read data into buffer to calculate checksums and size
var buf bytes.Buffer
md5Hash := md5.New()
sha256Hash := sha256.New()
multiWriter := io.MultiWriter(&buf, md5Hash, sha256Hash)
written, err := io.Copy(multiWriter, data)
if err != nil {
metrics.RecordStorageOperation("s3", "put", "error")
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to read data")
}
// Check quota before upload
if s.quota > 0 {
s.mu.RLock()
used := s.used
s.mu.RUnlock()
if used+written > s.quota {
metrics.RecordStorageOperation("s3", "put", "quota_exceeded")
return errors.QuotaExceeded(s.quota)
}
}
// Verify checksums if provided
if opts != nil {
md5Sum := hex.EncodeToString(md5Hash.Sum(nil))
sha256Sum := hex.EncodeToString(sha256Hash.Sum(nil))
if opts.ChecksumMD5 != "" && opts.ChecksumMD5 != md5Sum {
metrics.RecordStorageOperation("s3", "put", "checksum_error")
return errors.New(errors.ErrCodeChecksumMismatch, "MD5 checksum mismatch")
}
if opts.ChecksumSHA256 != "" && opts.ChecksumSHA256 != sha256Sum {
metrics.RecordStorageOperation("s3", "put", "checksum_error")
return errors.New(errors.ErrCodeChecksumMismatch, "SHA256 checksum mismatch")
}
}
// Prepare metadata
metadata := make(map[string]string)
if opts != nil && opts.Metadata != nil {
metadata = opts.Metadata
}
// Build put input
input := &s3.PutObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(s3Key),
Body: bytes.NewReader(buf.Bytes()),
Metadata: metadata,
}
if opts != nil && opts.ContentType != "" {
input.ContentType = aws.String(opts.ContentType)
}
// Upload to S3
_, err = s.client.PutObject(ctx, input)
if err != nil {
metrics.RecordStorageOperation("s3", "put", "error")
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to upload to S3")
}
// Update usage
s.mu.Lock()
s.used += written
currentUsed := s.used
s.mu.Unlock()
metrics.RecordStorageOperation("s3", "put", "success")
metrics.UpdateCacheSize("s3", currentUsed)
return nil
}
// Delete removes a file from S3
func (s *S3Storage) Delete(ctx context.Context, key string) error {
s3Key := s.buildKey(key)
// Get size before deletion for quota tracking
statInfo, err := s.Stat(ctx, key)
if err != nil {
return err
}
input := &s3.DeleteObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(s3Key),
}
_, err = s.client.DeleteObject(ctx, input)
if err != nil {
metrics.RecordStorageOperation("s3", "delete", "error")
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to delete from S3")
}
// Update usage
s.mu.Lock()
s.used -= statInfo.Size
currentUsed := s.used
s.mu.Unlock()
metrics.RecordStorageOperation("s3", "delete", "success")
metrics.UpdateCacheSize("s3", currentUsed)
return nil
}
// Exists checks if a file exists in S3
func (s *S3Storage) Exists(ctx context.Context, key string) (bool, error) {
s3Key := s.buildKey(key)
input := &s3.HeadObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(s3Key),
}
_, err := s.client.HeadObject(ctx, input)
if err != nil {
if isNotFoundError(err) {
return false, nil
}
return false, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to check existence in S3")
}
return true, nil
}
// List lists files with prefix in S3
func (s *S3Storage) List(ctx context.Context, prefix string, opts *storage.ListOptions) ([]storage.StorageObject, error) {
s3Prefix := s.buildKey(prefix)
input := &s3.ListObjectsV2Input{
Bucket: aws.String(s.bucket),
Prefix: aws.String(s3Prefix),
}
var objects []storage.StorageObject
paginator := s3.NewListObjectsV2Paginator(s.client, input)
for paginator.HasMorePages() {
page, err := paginator.NextPage(ctx)
if err != nil {
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to list objects in S3")
}
for _, obj := range page.Contents {
key := s.stripPrefix(*obj.Key)
objects = append(objects, storage.StorageObject{
Key: key,
Size: *obj.Size,
Modified: *obj.LastModified,
ETag: strings.Trim(*obj.ETag, "\""),
})
}
}
// Apply pagination if requested
if opts != nil {
start := opts.Offset
end := len(objects)
if opts.MaxResults > 0 && start+opts.MaxResults < end {
end = start + opts.MaxResults
}
if start < len(objects) {
objects = objects[start:end]
} else {
objects = []storage.StorageObject{}
}
}
return objects, nil
}
// Stat gets file metadata from S3
func (s *S3Storage) Stat(ctx context.Context, key string) (*storage.StorageInfo, error) {
s3Key := s.buildKey(key)
input := &s3.HeadObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(s3Key),
}
result, err := s.client.HeadObject(ctx, input)
if err != nil {
if isNotFoundError(err) {
return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key))
}
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to stat object in S3")
}
info := &storage.StorageInfo{
Key: key,
Size: *result.ContentLength,
Modified: *result.LastModified,
ETag: strings.Trim(*result.ETag, "\""),
Metadata: result.Metadata,
}
if result.ContentType != nil {
info.ContentType = *result.ContentType
}
return info, nil
}
// GetQuota returns quota information
func (s *S3Storage) GetQuota(ctx context.Context) (*storage.QuotaInfo, error) {
s.mu.RLock()
used := s.used
s.mu.RUnlock()
available := s.quota - used
if available < 0 {
available = 0
}
return &storage.QuotaInfo{
Used: used,
Available: available,
Limit: s.quota,
}, nil
}
// Health checks S3 health
func (s *S3Storage) Health(ctx context.Context) error {
// Try to list bucket to verify connectivity
input := &s3.ListObjectsV2Input{
Bucket: aws.String(s.bucket),
MaxKeys: aws.Int32(1),
}
_, err := s.client.ListObjectsV2(ctx, input)
if err != nil {
return errors.Wrap(err, errors.ErrCodeStorageFailure, "S3 health check failed")
}
return nil
}
// Close closes the storage backend
func (s *S3Storage) Close() error {
// No cleanup needed for S3 client
return nil
}
// buildKey builds the full S3 key with prefix
func (s *S3Storage) buildKey(key string) string {
key = strings.TrimPrefix(key, "/")
if s.prefix != "" {
return s.prefix + "/" + key
}
return key
}
// stripPrefix removes the configured prefix from an S3 key
func (s *S3Storage) stripPrefix(s3Key string) string {
if s.prefix != "" {
return strings.TrimPrefix(s3Key, s.prefix+"/")
}
return s3Key
}
// calculateUsage calculates current S3 storage usage
func (s *S3Storage) calculateUsage(ctx context.Context) error {
var total int64
input := &s3.ListObjectsV2Input{
Bucket: aws.String(s.bucket),
}
if s.prefix != "" {
input.Prefix = aws.String(s.prefix + "/")
}
paginator := s3.NewListObjectsV2Paginator(s.client, input)
for paginator.HasMorePages() {
page, err := paginator.NextPage(ctx)
if err != nil {
return err
}
for _, obj := range page.Contents {
total += *obj.Size
}
}
s.mu.Lock()
s.used = total
s.mu.Unlock()
metrics.UpdateCacheSize("s3", total)
return nil
}
// isNotFoundError checks if an error is a "not found" error
func isNotFoundError(err error) bool {
if err == nil {
return false
}
var notFound *types.NotFound
var noSuchKey *types.NoSuchKey
return stderrors.As(err, &notFound) || stderrors.As(err, &noSuchKey)
}
+579
View File
@@ -0,0 +1,579 @@
package smb
import (
"bytes"
"context"
"crypto/md5"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/hirochachacha/go-smb2"
"github.com/lukaszraczylo/gohoarder/pkg/errors"
"github.com/lukaszraczylo/gohoarder/pkg/metrics"
"github.com/lukaszraczylo/gohoarder/pkg/storage"
"github.com/rs/zerolog/log"
)
// SMBStorage implements storage.StorageBackend for SMB/CIFS shares
type SMBStorage struct {
host string
share string
basePath string
username string
password string
quota int64
mu sync.RWMutex
used int64
connPool chan *smbConnection
poolSize int
}
// smbConnection wraps an SMB session and share
type smbConnection struct {
conn net.Conn
session *smb2.Session
share *smb2.Share
lastUse time.Time
}
// Config holds SMB configuration
type Config struct {
Host string // SMB server hostname or IP
Port int // SMB server port (default: 445)
Share string // SMB share name
BasePath string // Base path within the share
Username string // SMB username
Password string // SMB password
Domain string // SMB domain (optional)
Quota int64 // Quota in bytes (0 = unlimited)
PoolSize int // Connection pool size (default: 5)
}
// New creates a new SMB storage backend
func New(ctx context.Context, cfg Config) (*SMBStorage, error) {
if cfg.Host == "" {
return nil, errors.New(errors.ErrCodeInvalidConfig, "SMB host is required")
}
if cfg.Share == "" {
return nil, errors.New(errors.ErrCodeInvalidConfig, "SMB share is required")
}
if cfg.Port == 0 {
cfg.Port = 445 // Default SMB port
}
if cfg.PoolSize == 0 {
cfg.PoolSize = 5 // Default pool size
}
smbStorage := &SMBStorage{
host: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
share: cfg.Share,
basePath: strings.Trim(cfg.BasePath, "/\\"),
username: cfg.Username,
password: cfg.Password,
quota: cfg.Quota,
connPool: make(chan *smbConnection, cfg.PoolSize),
poolSize: cfg.PoolSize,
}
// Initialize connection pool
for i := 0; i < cfg.PoolSize; i++ {
conn, err := smbStorage.createConnection(ctx)
if err != nil {
// Clean up any created connections
close(smbStorage.connPool)
for c := range smbStorage.connPool {
c.close()
}
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create SMB connection pool")
}
smbStorage.connPool <- conn
}
// Calculate initial usage
if err := smbStorage.calculateUsage(ctx); err != nil {
log.Warn().Err(err).Msg("Failed to calculate initial SMB storage usage")
}
return smbStorage, nil
}
// createConnection creates a new SMB connection
func (s *SMBStorage) createConnection(ctx context.Context) (*smbConnection, error) {
conn, err := net.Dial("tcp", s.host)
if err != nil {
return nil, err
}
dialer := &smb2.Dialer{
Initiator: &smb2.NTLMInitiator{
User: s.username,
Password: s.password,
},
}
session, err := dialer.Dial(conn)
if err != nil {
conn.Close()
return nil, err
}
share, err := session.Mount(s.share)
if err != nil {
session.Logoff()
conn.Close()
return nil, err
}
return &smbConnection{
conn: conn,
session: session,
share: share,
lastUse: time.Now(),
}, nil
}
// getConnection gets a connection from the pool
func (s *SMBStorage) getConnection(ctx context.Context) (*smbConnection, error) {
select {
case conn := <-s.connPool:
conn.lastUse = time.Now()
return conn, nil
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(30 * time.Second):
return nil, errors.New(errors.ErrCodeStorageFailure, "timeout waiting for SMB connection")
}
}
// returnConnection returns a connection to the pool
func (s *SMBStorage) returnConnection(conn *smbConnection) {
select {
case s.connPool <- conn:
default:
// Pool is full, close the connection
conn.close()
}
}
// close closes an SMB connection
func (c *smbConnection) close() {
if c.share != nil {
c.share.Umount()
}
if c.session != nil {
c.session.Logoff()
}
if c.conn != nil {
c.conn.Close()
}
}
// Get retrieves a file from SMB share
func (s *SMBStorage) Get(ctx context.Context, key string) (io.ReadCloser, error) {
conn, err := s.getConnection(ctx)
if err != nil {
return nil, err
}
path := s.keyToPath(key)
// Open file
file, err := conn.share.Open(path)
if err != nil {
s.returnConnection(conn)
if os.IsNotExist(err) {
metrics.RecordStorageOperation("smb", "get", "not_found")
return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key))
}
metrics.RecordStorageOperation("smb", "get", "error")
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to open SMB file")
}
// Read entire file into memory and close SMB connection
// This is necessary because we need to return the connection to the pool
data, err := io.ReadAll(file)
file.Close()
s.returnConnection(conn)
if err != nil {
metrics.RecordStorageOperation("smb", "get", "error")
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to read SMB file")
}
metrics.RecordStorageOperation("smb", "get", "success")
return io.NopCloser(bytes.NewReader(data)), nil
}
// Put stores a file on SMB share
func (s *SMBStorage) Put(ctx context.Context, key string, data io.Reader, opts *storage.PutOptions) error {
conn, err := s.getConnection(ctx)
if err != nil {
return err
}
defer s.returnConnection(conn)
path := s.keyToPath(key)
dir := filepath.Dir(path)
// Create directory structure
if err := conn.share.MkdirAll(dir, 0755); err != nil {
metrics.RecordStorageOperation("smb", "put", "error")
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create SMB directory")
}
// Read data into buffer to calculate checksums and size
var buf bytes.Buffer
md5Hash := md5.New()
sha256Hash := sha256.New()
multiWriter := io.MultiWriter(&buf, md5Hash, sha256Hash)
written, err := io.Copy(multiWriter, data)
if err != nil {
metrics.RecordStorageOperation("smb", "put", "error")
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to read data")
}
// Check quota
if s.quota > 0 {
s.mu.RLock()
used := s.used
s.mu.RUnlock()
if used+written > s.quota {
metrics.RecordStorageOperation("smb", "put", "quota_exceeded")
return errors.QuotaExceeded(s.quota)
}
}
// Verify checksums if provided
if opts != nil {
md5Sum := hex.EncodeToString(md5Hash.Sum(nil))
sha256Sum := hex.EncodeToString(sha256Hash.Sum(nil))
if opts.ChecksumMD5 != "" && opts.ChecksumMD5 != md5Sum {
metrics.RecordStorageOperation("smb", "put", "checksum_error")
return errors.New(errors.ErrCodeChecksumMismatch, "MD5 checksum mismatch")
}
if opts.ChecksumSHA256 != "" && opts.ChecksumSHA256 != sha256Sum {
metrics.RecordStorageOperation("smb", "put", "checksum_error")
return errors.New(errors.ErrCodeChecksumMismatch, "SHA256 checksum mismatch")
}
}
// Create temp file for atomic write
tempPath := path + ".tmp"
file, err := conn.share.Create(tempPath)
if err != nil {
metrics.RecordStorageOperation("smb", "put", "error")
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create SMB temp file")
}
// Write data
_, err = io.Copy(file, bytes.NewReader(buf.Bytes()))
file.Close()
if err != nil {
conn.share.Remove(tempPath)
metrics.RecordStorageOperation("smb", "put", "error")
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to write SMB file")
}
// Atomic rename
if err := conn.share.Rename(tempPath, path); err != nil {
conn.share.Remove(tempPath)
metrics.RecordStorageOperation("smb", "put", "error")
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to rename SMB temp file")
}
// Update usage
s.mu.Lock()
s.used += written
currentUsed := s.used
s.mu.Unlock()
metrics.RecordStorageOperation("smb", "put", "success")
metrics.UpdateCacheSize("smb", currentUsed)
return nil
}
// Delete removes a file from SMB share
func (s *SMBStorage) Delete(ctx context.Context, key string) error {
conn, err := s.getConnection(ctx)
if err != nil {
return err
}
defer s.returnConnection(conn)
path := s.keyToPath(key)
// Get size before deletion
info, err := conn.share.Stat(path)
if err != nil {
if os.IsNotExist(err) {
metrics.RecordStorageOperation("smb", "delete", "not_found")
return errors.NotFound(fmt.Sprintf("file not found: %s", key))
}
metrics.RecordStorageOperation("smb", "delete", "error")
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to stat SMB file")
}
size := info.Size()
if err := conn.share.Remove(path); err != nil {
metrics.RecordStorageOperation("smb", "delete", "error")
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to delete SMB file")
}
// Update usage
s.mu.Lock()
s.used -= size
currentUsed := s.used
s.mu.Unlock()
metrics.RecordStorageOperation("smb", "delete", "success")
metrics.UpdateCacheSize("smb", currentUsed)
return nil
}
// Exists checks if a file exists on SMB share
func (s *SMBStorage) Exists(ctx context.Context, key string) (bool, error) {
conn, err := s.getConnection(ctx)
if err != nil {
return false, err
}
defer s.returnConnection(conn)
path := s.keyToPath(key)
_, err = conn.share.Stat(path)
if err != nil {
if os.IsNotExist(err) {
return false, nil
}
return false, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to check SMB file existence")
}
return true, nil
}
// List lists files with prefix on SMB share
func (s *SMBStorage) List(ctx context.Context, prefix string, opts *storage.ListOptions) ([]storage.StorageObject, error) {
conn, err := s.getConnection(ctx)
if err != nil {
return nil, err
}
defer s.returnConnection(conn)
searchPath := s.keyToPath(prefix)
var objects []storage.StorageObject
err = s.walkPath(conn.share, searchPath, func(path string, info os.FileInfo) error {
if info.IsDir() {
return nil
}
key := s.pathToKey(path)
objects = append(objects, storage.StorageObject{
Key: key,
Size: info.Size(),
Modified: info.ModTime(),
})
return nil
})
if err != nil && !os.IsNotExist(err) {
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to list SMB files")
}
// Apply pagination if requested
if opts != nil {
start := opts.Offset
end := len(objects)
if opts.MaxResults > 0 && start+opts.MaxResults < end {
end = start + opts.MaxResults
}
if start < len(objects) {
objects = objects[start:end]
} else {
objects = []storage.StorageObject{}
}
}
return objects, nil
}
// Stat gets file metadata from SMB share
func (s *SMBStorage) Stat(ctx context.Context, key string) (*storage.StorageInfo, error) {
conn, err := s.getConnection(ctx)
if err != nil {
return nil, err
}
defer s.returnConnection(conn)
path := s.keyToPath(key)
info, err := conn.share.Stat(path)
if err != nil {
if os.IsNotExist(err) {
return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key))
}
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to stat SMB file")
}
return &storage.StorageInfo{
Key: key,
Size: info.Size(),
Modified: info.ModTime(),
}, nil
}
// GetQuota returns quota information
func (s *SMBStorage) GetQuota(ctx context.Context) (*storage.QuotaInfo, error) {
s.mu.RLock()
used := s.used
s.mu.RUnlock()
available := s.quota - used
if available < 0 {
available = 0
}
return &storage.QuotaInfo{
Used: used,
Available: available,
Limit: s.quota,
}, nil
}
// Health checks SMB health
func (s *SMBStorage) Health(ctx context.Context) error {
conn, err := s.getConnection(ctx)
if err != nil {
return errors.Wrap(err, errors.ErrCodeStorageFailure, "SMB health check failed - connection error")
}
defer s.returnConnection(conn)
// Try to stat the base path
path := s.keyToPath("")
_, err = conn.share.Stat(path)
if err != nil && !os.IsNotExist(err) {
return errors.Wrap(err, errors.ErrCodeStorageFailure, "SMB health check failed")
}
return nil
}
// Close closes the storage backend
func (s *SMBStorage) Close() error {
close(s.connPool)
for conn := range s.connPool {
conn.close()
}
return nil
}
// keyToPath converts a storage key to SMB path
func (s *SMBStorage) keyToPath(key string) string {
key = strings.TrimPrefix(key, "/")
key = filepath.Clean(key)
// Remove path traversal attempts
for strings.HasPrefix(key, "../") || strings.HasPrefix(key, "..\\") {
key = strings.TrimPrefix(key, "../")
key = strings.TrimPrefix(key, "..\\")
}
key = filepath.Clean(key)
if key == ".." || strings.HasPrefix(key, "../") || strings.HasPrefix(key, "..\\") {
key = ""
}
if s.basePath != "" {
return filepath.Join(s.basePath, key)
}
return key
}
// pathToKey converts an SMB path back to a storage key
func (s *SMBStorage) pathToKey(path string) string {
if s.basePath != "" {
path = strings.TrimPrefix(path, s.basePath)
path = strings.TrimPrefix(path, "/")
path = strings.TrimPrefix(path, "\\")
}
return filepath.ToSlash(path)
}
// walkPath recursively walks an SMB directory
func (s *SMBStorage) walkPath(share *smb2.Share, path string, fn func(string, os.FileInfo) error) error {
info, err := share.Stat(path)
if err != nil {
return err
}
if !info.IsDir() {
return fn(path, info)
}
entries, err := share.ReadDir(path)
if err != nil {
return err
}
for _, entry := range entries {
entryPath := filepath.Join(path, entry.Name())
if entry.IsDir() {
if err := s.walkPath(share, entryPath, fn); err != nil {
return err
}
} else {
if err := fn(entryPath, entry); err != nil {
return err
}
}
}
return nil
}
// calculateUsage calculates current SMB storage usage
func (s *SMBStorage) calculateUsage(ctx context.Context) error {
conn, err := s.getConnection(ctx)
if err != nil {
return err
}
defer s.returnConnection(conn)
var total int64
basePath := s.keyToPath("")
err = s.walkPath(conn.share, basePath, func(path string, info os.FileInfo) error {
if !info.IsDir() {
total += info.Size()
}
return nil
})
if err != nil && !os.IsNotExist(err) {
return err
}
s.mu.Lock()
s.used = total
s.mu.Unlock()
metrics.UpdateCacheSize("smb", total)
return nil
}
+30
View File
@@ -0,0 +1,30 @@
package uuid
import (
"crypto/rand"
"fmt"
)
// UUID represents a UUID (RFC 4122)
type UUID [16]byte
// New generates a random UUID v4
func New() UUID {
var u UUID
// Read random bytes
if _, err := rand.Read(u[:]); err != nil {
panic(fmt.Sprintf("failed to generate UUID: %v", err))
}
// Set version (4) and variant (RFC 4122)
u[6] = (u[6] & 0x0f) | 0x40 // Version 4
u[8] = (u[8] & 0x3f) | 0x80 // Variant RFC 4122
return u
}
// String returns the UUID in standard format (xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx)
func (u UUID) String() string {
return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x",
u[0:4], u[4:6], u[6:8], u[8:10], u[10:16])
}
+217
View File
@@ -0,0 +1,217 @@
package uuid
import (
"regexp"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestNew tests UUID generation
func TestNew(t *testing.T) {
tests := []struct {
name string
runs int
}{
{
name: "generate single UUID",
runs: 1,
},
{
name: "generate multiple UUIDs",
runs: 100,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
seen := make(map[string]bool)
for i := 0; i < tt.runs; i++ {
uuid := New()
// Verify UUID is 16 bytes
assert.Equal(t, 16, len(uuid))
// Verify version is 4
version := (uuid[6] >> 4) & 0x0f
assert.Equal(t, uint8(4), version, "UUID version should be 4")
// Verify variant is RFC 4122
variant := (uuid[8] >> 6) & 0x03
assert.Equal(t, uint8(2), variant, "UUID variant should be RFC 4122 (10 in binary)")
// Check uniqueness
str := uuid.String()
assert.False(t, seen[str], "UUID should be unique")
seen[str] = true
}
})
}
}
// TestString tests UUID string formatting
func TestString(t *testing.T) {
tests := []struct {
name string
uuid UUID
expected string
}{
{
name: "zero UUID",
uuid: UUID{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
expected: "00000000-0000-0000-0000-000000000000",
},
{
name: "all ones UUID",
uuid: UUID{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255},
expected: "ffffffff-ffff-ffff-ffff-ffffffffffff",
},
{
name: "mixed values UUID",
uuid: UUID{0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88},
expected: "12345678-9abc-def0-1122-334455667788",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
str := tt.uuid.String()
assert.Equal(t, tt.expected, str)
// Verify format matches UUID regex
matched, err := regexp.MatchString(`^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`, str)
require.NoError(t, err)
assert.True(t, matched, "UUID string should match standard format")
// Verify dashes are in correct positions
assert.Equal(t, "-", string(str[8]))
assert.Equal(t, "-", string(str[13]))
assert.Equal(t, "-", string(str[18]))
assert.Equal(t, "-", string(str[23]))
// Verify length
assert.Equal(t, 36, len(str))
})
}
}
// TestUUIDFormat tests that generated UUIDs match the standard format
func TestUUIDFormat(t *testing.T) {
const iterations = 1000
// Compile regex once for performance
hexPattern := regexp.MustCompile(`^[0-9a-f]+$`)
for i := 0; i < iterations; i++ {
uuid := New()
str := uuid.String()
// Test standard UUID format: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
parts := strings.Split(str, "-")
require.Len(t, parts, 5, "UUID should have 5 parts separated by dashes")
assert.Equal(t, 8, len(parts[0]), "First part should be 8 characters")
assert.Equal(t, 4, len(parts[1]), "Second part should be 4 characters")
assert.Equal(t, 4, len(parts[2]), "Third part should be 4 characters")
assert.Equal(t, 4, len(parts[3]), "Fourth part should be 4 characters")
assert.Equal(t, 12, len(parts[4]), "Fifth part should be 12 characters")
// Verify all characters are hexadecimal
for _, part := range parts {
assert.True(t, hexPattern.MatchString(part), "UUID parts should only contain hex characters")
}
// Verify version bits (4th character of third part should start with 4)
versionChar := parts[2][0]
assert.Equal(t, byte('4'), versionChar, "UUID version should be 4")
// Verify variant bits (first character of fourth part should be 8, 9, a, or b)
variantChar := parts[3][0]
assert.Contains(t, []byte{'8', '9', 'a', 'b'}, variantChar, "UUID variant should be RFC 4122")
}
}
// TestConcurrentGeneration tests that UUID generation is safe for concurrent use
func TestConcurrentGeneration(t *testing.T) {
const numGoroutines = 100
const uuidsPerGoroutine = 100
results := make(chan UUID, numGoroutines*uuidsPerGoroutine)
// Generate UUIDs concurrently
for i := 0; i < numGoroutines; i++ {
go func() {
for j := 0; j < uuidsPerGoroutine; j++ {
results <- New()
}
}()
}
// Collect all UUIDs
seen := make(map[string]bool)
for i := 0; i < numGoroutines*uuidsPerGoroutine; i++ {
uuid := <-results
str := uuid.String()
// Verify uniqueness
assert.False(t, seen[str], "UUID should be unique even in concurrent generation")
seen[str] = true
// Verify version and variant
version := (uuid[6] >> 4) & 0x0f
assert.Equal(t, uint8(4), version)
variant := (uuid[8] >> 6) & 0x03
assert.Equal(t, uint8(2), variant)
}
// Verify we got all expected UUIDs
assert.Equal(t, numGoroutines*uuidsPerGoroutine, len(seen))
}
// TestUUIDEquality tests UUID equality
func TestUUIDEquality(t *testing.T) {
uuid1 := New()
uuid2 := New()
// Different UUIDs should not be equal
assert.NotEqual(t, uuid1, uuid2)
assert.NotEqual(t, uuid1.String(), uuid2.String())
// Same UUID should be equal
uuid3 := uuid1
assert.Equal(t, uuid1, uuid3)
assert.Equal(t, uuid1.String(), uuid3.String())
}
// TestUUIDArrayAccess tests that UUID can be accessed as a byte array
func TestUUIDArrayAccess(t *testing.T) {
uuid := New()
// Verify we can access all bytes
for i := 0; i < 16; i++ {
_ = uuid[i]
}
// Verify length
assert.Equal(t, 16, len(uuid))
}
// BenchmarkNew benchmarks UUID generation
func BenchmarkNew(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = New()
}
}
// BenchmarkString benchmarks UUID string conversion
func BenchmarkString(b *testing.B) {
uuid := New()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = uuid.String()
}
}
+388
View File
@@ -0,0 +1,388 @@
package websocket
import (
"context"
"encoding/json"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/rs/zerolog/log"
)
// EventType represents the type of event being broadcast
type EventType string
const (
EventPackageCached EventType = "package_cached"
EventPackageDeleted EventType = "package_deleted"
EventPackageDownloaded EventType = "package_downloaded"
EventScanComplete EventType = "scan_complete"
EventStatsUpdate EventType = "stats_update"
EventSystemAlert EventType = "system_alert"
)
// Event represents a WebSocket event message
type Event struct {
Type EventType `json:"type"`
Timestamp time.Time `json:"timestamp"`
Data map[string]interface{} `json:"data"`
}
// Client represents a WebSocket client connection
type Client struct {
conn *websocket.Conn
send chan []byte
server *Server
subscriptions map[EventType]bool
mu sync.RWMutex
}
// Server manages WebSocket connections and event broadcasting
type Server struct {
clients map[*Client]bool
broadcast chan Event
register chan *Client
unregister chan *Client
mu sync.RWMutex
upgrader websocket.Upgrader
}
// Config holds WebSocket server configuration
type Config struct {
ReadBufferSize int
WriteBufferSize int
CheckOrigin func(r *http.Request) bool
}
// NewServer creates a new WebSocket server
func NewServer(cfg Config) *Server {
if cfg.CheckOrigin == nil {
cfg.CheckOrigin = func(r *http.Request) bool {
return true // Allow all origins by default
}
}
server := &Server{
clients: make(map[*Client]bool),
broadcast: make(chan Event, 256),
register: make(chan *Client),
unregister: make(chan *Client),
upgrader: websocket.Upgrader{
ReadBufferSize: cfg.ReadBufferSize,
WriteBufferSize: cfg.WriteBufferSize,
CheckOrigin: cfg.CheckOrigin,
},
}
return server
}
// Start starts the WebSocket server event loop
func (s *Server) Start(ctx context.Context) {
go s.run(ctx)
log.Info().Msg("WebSocket server started")
}
// run handles client registration/unregistration and broadcasting
func (s *Server) run(ctx context.Context) {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
log.Info().Msg("WebSocket server shutting down")
s.closeAllClients()
return
case client := <-s.register:
s.mu.Lock()
s.clients[client] = true
s.mu.Unlock()
log.Debug().
Int("total_clients", len(s.clients)).
Msg("Client registered")
case client := <-s.unregister:
s.mu.Lock()
if _, ok := s.clients[client]; ok {
delete(s.clients, client)
close(client.send)
}
s.mu.Unlock()
log.Debug().
Int("total_clients", len(s.clients)).
Msg("Client unregistered")
case event := <-s.broadcast:
s.broadcastEvent(event)
case <-ticker.C:
// Ping all clients to keep connections alive
s.pingClients()
}
}
}
// broadcastEvent sends an event to all subscribed clients
func (s *Server) broadcastEvent(event Event) {
message, err := json.Marshal(event)
if err != nil {
log.Error().Err(err).Msg("Failed to marshal event")
return
}
s.mu.RLock()
defer s.mu.RUnlock()
for client := range s.clients {
// Check if client is subscribed to this event type
client.mu.RLock()
subscribed := len(client.subscriptions) == 0 || client.subscriptions[event.Type]
client.mu.RUnlock()
if subscribed {
select {
case client.send <- message:
default:
// Client send buffer full - close connection
go func(c *Client) {
s.unregister <- c
}(client)
}
}
}
log.Debug().
Str("event_type", string(event.Type)).
Int("clients_notified", len(s.clients)).
Msg("Event broadcast")
}
// pingClients sends ping messages to all connected clients
func (s *Server) pingClients() {
s.mu.RLock()
defer s.mu.RUnlock()
for client := range s.clients {
if err := client.conn.WriteControl(
websocket.PingMessage,
[]byte{},
time.Now().Add(10*time.Second),
); err != nil {
log.Debug().Err(err).Msg("Failed to ping client")
go func(c *Client) {
s.unregister <- c
}(client)
}
}
}
// closeAllClients closes all client connections
func (s *Server) closeAllClients() {
s.mu.Lock()
defer s.mu.Unlock()
for client := range s.clients {
client.conn.Close()
close(client.send)
}
s.clients = make(map[*Client]bool)
}
// Broadcast sends an event to all connected clients
func (s *Server) Broadcast(eventType EventType, data map[string]interface{}) {
event := Event{
Type: eventType,
Timestamp: time.Now(),
Data: data,
}
select {
case s.broadcast <- event:
default:
log.Warn().Msg("Broadcast channel full - dropping event")
}
}
// HandleWebSocket upgrades HTTP connection to WebSocket
func (s *Server) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
conn, err := s.upgrader.Upgrade(w, r, nil)
if err != nil {
log.Error().Err(err).Msg("Failed to upgrade connection")
return
}
client := &Client{
conn: conn,
send: make(chan []byte, 256),
server: s,
subscriptions: make(map[EventType]bool),
}
s.register <- client
// Start goroutines for reading and writing
go client.readPump()
go client.writePump()
log.Info().
Str("remote_addr", r.RemoteAddr).
Msg("WebSocket connection established")
}
// readPump handles incoming messages from the client
func (c *Client) readPump() {
defer func() {
c.server.unregister <- c
c.conn.Close()
}()
c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
c.conn.SetPongHandler(func(string) error {
c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
return nil
})
for {
_, message, err := c.conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
log.Error().Err(err).Msg("WebSocket read error")
}
break
}
// Handle client messages (subscriptions, etc.)
c.handleMessage(message)
}
}
// writePump handles outgoing messages to the client
func (c *Client) writePump() {
ticker := time.NewTicker(54 * time.Second)
defer func() {
ticker.Stop()
c.conn.Close()
}()
for {
select {
case message, ok := <-c.send:
c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if !ok {
// Channel closed
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}
w, err := c.conn.NextWriter(websocket.TextMessage)
if err != nil {
return
}
w.Write(message)
// Write any additional queued messages
n := len(c.send)
for i := 0; i < n; i++ {
w.Write([]byte{'\n'})
w.Write(<-c.send)
}
if err := w.Close(); err != nil {
return
}
case <-ticker.C:
c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}
}
}
}
// handleMessage processes incoming client messages
func (c *Client) handleMessage(message []byte) {
var msg struct {
Action string `json:"action"`
Data interface{} `json:"data"`
}
if err := json.Unmarshal(message, &msg); err != nil {
log.Error().Err(err).Msg("Failed to unmarshal client message")
return
}
switch msg.Action {
case "subscribe":
c.handleSubscribe(msg.Data)
case "unsubscribe":
c.handleUnsubscribe(msg.Data)
case "ping":
c.sendPong()
default:
log.Warn().Str("action", msg.Action).Msg("Unknown client action")
}
}
// handleSubscribe subscribes the client to specific event types
func (c *Client) handleSubscribe(data interface{}) {
eventTypes, ok := data.([]interface{})
if !ok {
log.Error().Msg("Invalid subscribe data format")
return
}
c.mu.Lock()
defer c.mu.Unlock()
for _, et := range eventTypes {
if eventType, ok := et.(string); ok {
c.subscriptions[EventType(eventType)] = true
log.Debug().
Str("event_type", eventType).
Msg("Client subscribed to event type")
}
}
}
// handleUnsubscribe unsubscribes the client from specific event types
func (c *Client) handleUnsubscribe(data interface{}) {
eventTypes, ok := data.([]interface{})
if !ok {
log.Error().Msg("Invalid unsubscribe data format")
return
}
c.mu.Lock()
defer c.mu.Unlock()
for _, et := range eventTypes {
if eventType, ok := et.(string); ok {
delete(c.subscriptions, EventType(eventType))
log.Debug().
Str("event_type", eventType).
Msg("Client unsubscribed from event type")
}
}
}
// sendPong sends a pong response to the client
func (c *Client) sendPong() {
response := map[string]string{"type": "pong"}
message, _ := json.Marshal(response)
select {
case c.send <- message:
default:
}
}
// GetConnectedClients returns the number of connected clients
func (s *Server) GetConnectedClients() int {
s.mu.RLock()
defer s.mu.RUnlock()
return len(s.clients)
}