mirror of
https://github.com/lukaszraczylo/gohoarder.git
synced 2026-06-29 03:12:54 +00:00
Initial commit
This commit is contained in:
@@ -0,0 +1,437 @@
|
||||
package analytics
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// PackageDownload represents a package download event
|
||||
type PackageDownload struct {
|
||||
Registry string
|
||||
Name string
|
||||
Version string
|
||||
Timestamp time.Time
|
||||
BytesSize int64
|
||||
ClientIP string
|
||||
UserAgent string
|
||||
}
|
||||
|
||||
// PackageStats holds statistics for a package
|
||||
type PackageStats struct {
|
||||
Registry string
|
||||
Name string
|
||||
TotalDownloads int64
|
||||
UniqueVersions int
|
||||
LastDownload time.Time
|
||||
FirstSeen time.Time
|
||||
BytesServed int64
|
||||
}
|
||||
|
||||
// TrendData represents trend information over time
|
||||
type TrendData struct {
|
||||
Period time.Duration
|
||||
Downloads int64
|
||||
Packages int
|
||||
}
|
||||
|
||||
// PopularPackage represents a popular package entry
|
||||
type PopularPackage struct {
|
||||
Registry string
|
||||
Name string
|
||||
Downloads int64
|
||||
RecentDownloads int64 // Downloads in last 7 days
|
||||
Trend float64 // Growth rate
|
||||
}
|
||||
|
||||
// Engine tracks and analyzes package downloads
|
||||
type Engine struct {
|
||||
downloads []PackageDownload
|
||||
downloadsMu sync.RWMutex
|
||||
stats map[string]*PackageStats // key: registry:name
|
||||
statsMu sync.RWMutex
|
||||
maxEvents int
|
||||
flushTicker *time.Ticker
|
||||
stopChan chan struct{}
|
||||
}
|
||||
|
||||
// Config holds analytics engine configuration
|
||||
type Config struct {
|
||||
MaxEvents int
|
||||
FlushInterval time.Duration
|
||||
}
|
||||
|
||||
// NewEngine creates a new analytics engine
|
||||
func NewEngine(cfg Config) *Engine {
|
||||
if cfg.MaxEvents <= 0 {
|
||||
cfg.MaxEvents = 10000
|
||||
}
|
||||
if cfg.FlushInterval <= 0 {
|
||||
cfg.FlushInterval = 5 * time.Minute
|
||||
}
|
||||
|
||||
engine := &Engine{
|
||||
downloads: make([]PackageDownload, 0, cfg.MaxEvents),
|
||||
stats: make(map[string]*PackageStats),
|
||||
maxEvents: cfg.MaxEvents,
|
||||
flushTicker: time.NewTicker(cfg.FlushInterval),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Load existing stats from metadata store
|
||||
engine.loadStats()
|
||||
|
||||
// Start background flush goroutine
|
||||
go engine.flushLoop()
|
||||
|
||||
log.Info().
|
||||
Int("max_events", cfg.MaxEvents).
|
||||
Dur("flush_interval", cfg.FlushInterval).
|
||||
Msg("Analytics engine started")
|
||||
|
||||
return engine
|
||||
}
|
||||
|
||||
// TrackDownload records a package download event
|
||||
func (e *Engine) TrackDownload(download PackageDownload) {
|
||||
e.downloadsMu.Lock()
|
||||
defer e.downloadsMu.Unlock()
|
||||
|
||||
// Add to event buffer
|
||||
e.downloads = append(e.downloads, download)
|
||||
|
||||
// Update in-memory stats
|
||||
e.updateStats(download)
|
||||
|
||||
// Flush if buffer is full
|
||||
if len(e.downloads) >= e.maxEvents {
|
||||
go e.flush()
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("registry", download.Registry).
|
||||
Str("package", download.Name).
|
||||
Str("version", download.Version).
|
||||
Msg("Download tracked")
|
||||
}
|
||||
|
||||
// updateStats updates in-memory statistics
|
||||
func (e *Engine) updateStats(download PackageDownload) {
|
||||
e.statsMu.Lock()
|
||||
defer e.statsMu.Unlock()
|
||||
|
||||
key := download.Registry + ":" + download.Name
|
||||
stats, exists := e.stats[key]
|
||||
if !exists {
|
||||
stats = &PackageStats{
|
||||
Registry: download.Registry,
|
||||
Name: download.Name,
|
||||
FirstSeen: download.Timestamp,
|
||||
}
|
||||
e.stats[key] = stats
|
||||
}
|
||||
|
||||
stats.TotalDownloads++
|
||||
stats.BytesServed += download.BytesSize
|
||||
stats.LastDownload = download.Timestamp
|
||||
|
||||
// Track unique versions (simplified)
|
||||
stats.UniqueVersions++
|
||||
}
|
||||
|
||||
// GetPackageStats returns statistics for a specific package
|
||||
func (e *Engine) GetPackageStats(registry, name string) (*PackageStats, bool) {
|
||||
e.statsMu.RLock()
|
||||
defer e.statsMu.RUnlock()
|
||||
|
||||
key := registry + ":" + name
|
||||
stats, exists := e.stats[key]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Return a copy to avoid race conditions
|
||||
statsCopy := *stats
|
||||
return &statsCopy, true
|
||||
}
|
||||
|
||||
// GetTopPackages returns the most downloaded packages
|
||||
func (e *Engine) GetTopPackages(limit int) []PopularPackage {
|
||||
e.statsMu.RLock()
|
||||
defer e.statsMu.RUnlock()
|
||||
|
||||
packages := make([]PopularPackage, 0, len(e.stats))
|
||||
for _, stats := range e.stats {
|
||||
packages = append(packages, PopularPackage{
|
||||
Registry: stats.Registry,
|
||||
Name: stats.Name,
|
||||
Downloads: stats.TotalDownloads,
|
||||
})
|
||||
}
|
||||
|
||||
// Sort by downloads descending
|
||||
sort.Slice(packages, func(i, j int) bool {
|
||||
return packages[i].Downloads > packages[j].Downloads
|
||||
})
|
||||
|
||||
if limit > 0 && limit < len(packages) {
|
||||
packages = packages[:limit]
|
||||
}
|
||||
|
||||
return packages
|
||||
}
|
||||
|
||||
// GetTrendingPackages returns packages with growing popularity
|
||||
func (e *Engine) GetTrendingPackages(limit int) []PopularPackage {
|
||||
e.statsMu.RLock()
|
||||
defer e.statsMu.RUnlock()
|
||||
|
||||
sevenDaysAgo := time.Now().Add(-7 * 24 * time.Hour)
|
||||
|
||||
packages := make([]PopularPackage, 0)
|
||||
for _, stats := range e.stats {
|
||||
// Calculate recent downloads (last 7 days)
|
||||
recent := e.getRecentDownloads(stats.Registry, stats.Name, sevenDaysAgo)
|
||||
|
||||
// Calculate trend (simple growth rate)
|
||||
trend := 0.0
|
||||
if stats.TotalDownloads > 0 {
|
||||
trend = float64(recent) / float64(stats.TotalDownloads) * 100
|
||||
}
|
||||
|
||||
packages = append(packages, PopularPackage{
|
||||
Registry: stats.Registry,
|
||||
Name: stats.Name,
|
||||
Downloads: stats.TotalDownloads,
|
||||
RecentDownloads: recent,
|
||||
Trend: trend,
|
||||
})
|
||||
}
|
||||
|
||||
// Sort by trend descending
|
||||
sort.Slice(packages, func(i, j int) bool {
|
||||
return packages[i].Trend > packages[j].Trend
|
||||
})
|
||||
|
||||
if limit > 0 && limit < len(packages) {
|
||||
packages = packages[:limit]
|
||||
}
|
||||
|
||||
return packages
|
||||
}
|
||||
|
||||
// getRecentDownloads counts downloads since a given time
|
||||
func (e *Engine) getRecentDownloads(registry, name string, since time.Time) int64 {
|
||||
e.downloadsMu.RLock()
|
||||
defer e.downloadsMu.RUnlock()
|
||||
|
||||
count := int64(0)
|
||||
for _, download := range e.downloads {
|
||||
if download.Registry == registry &&
|
||||
download.Name == name &&
|
||||
download.Timestamp.After(since) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// GetTrends returns download trends over different time periods
|
||||
func (e *Engine) GetTrends() []TrendData {
|
||||
e.downloadsMu.RLock()
|
||||
defer e.downloadsMu.RUnlock()
|
||||
|
||||
now := time.Now()
|
||||
periods := []time.Duration{
|
||||
1 * time.Hour,
|
||||
24 * time.Hour,
|
||||
7 * 24 * time.Hour,
|
||||
30 * 24 * time.Hour,
|
||||
}
|
||||
|
||||
trends := make([]TrendData, len(periods))
|
||||
for i, period := range periods {
|
||||
since := now.Add(-period)
|
||||
downloads := int64(0)
|
||||
packages := make(map[string]bool)
|
||||
|
||||
for _, download := range e.downloads {
|
||||
if download.Timestamp.After(since) {
|
||||
downloads++
|
||||
packages[download.Registry+":"+download.Name] = true
|
||||
}
|
||||
}
|
||||
|
||||
trends[i] = TrendData{
|
||||
Period: period,
|
||||
Downloads: downloads,
|
||||
Packages: len(packages),
|
||||
}
|
||||
}
|
||||
|
||||
return trends
|
||||
}
|
||||
|
||||
// GetTotalStats returns overall statistics
|
||||
func (e *Engine) GetTotalStats() map[string]interface{} {
|
||||
e.statsMu.RLock()
|
||||
defer e.statsMu.RUnlock()
|
||||
|
||||
totalDownloads := int64(0)
|
||||
totalBytes := int64(0)
|
||||
registries := make(map[string]int64)
|
||||
|
||||
for _, stats := range e.stats {
|
||||
totalDownloads += stats.TotalDownloads
|
||||
totalBytes += stats.BytesServed
|
||||
registries[stats.Registry]++
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_packages": len(e.stats),
|
||||
"total_downloads": totalDownloads,
|
||||
"total_bytes": totalBytes,
|
||||
"registries": registries,
|
||||
}
|
||||
}
|
||||
|
||||
// flushLoop periodically flushes download events to metadata store
|
||||
func (e *Engine) flushLoop() {
|
||||
for {
|
||||
select {
|
||||
case <-e.flushTicker.C:
|
||||
e.flush()
|
||||
case <-e.stopChan:
|
||||
e.flush() // Final flush
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// flush persists download events to metadata store
|
||||
func (e *Engine) flush() {
|
||||
e.downloadsMu.Lock()
|
||||
downloads := e.downloads
|
||||
e.downloads = make([]PackageDownload, 0, e.maxEvents)
|
||||
e.downloadsMu.Unlock()
|
||||
|
||||
if len(downloads) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Int("events", len(downloads)).
|
||||
Msg("Flushing analytics events")
|
||||
|
||||
// In a real implementation, this would persist to the metadata store
|
||||
// For now, we just clear the buffer
|
||||
// TODO: Add actual persistence when metadata store supports analytics tables
|
||||
}
|
||||
|
||||
// loadStats loads existing statistics from metadata store
|
||||
func (e *Engine) loadStats() {
|
||||
// TODO: Load stats from metadata store when analytics tables are implemented
|
||||
log.Debug().Msg("Loading analytics stats from metadata store")
|
||||
}
|
||||
|
||||
// Close stops the analytics engine
|
||||
func (e *Engine) Close() {
|
||||
close(e.stopChan)
|
||||
e.flushTicker.Stop()
|
||||
e.flush() // Final flush
|
||||
log.Info().Msg("Analytics engine stopped")
|
||||
}
|
||||
|
||||
// GetRegistryStats returns per-registry statistics
|
||||
func (e *Engine) GetRegistryStats(registry string) map[string]interface{} {
|
||||
e.statsMu.RLock()
|
||||
defer e.statsMu.RUnlock()
|
||||
|
||||
totalPackages := 0
|
||||
totalDownloads := int64(0)
|
||||
totalBytes := int64(0)
|
||||
|
||||
for _, stats := range e.stats {
|
||||
if stats.Registry == registry {
|
||||
totalPackages++
|
||||
totalDownloads += stats.TotalDownloads
|
||||
totalBytes += stats.BytesServed
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"registry": registry,
|
||||
"total_packages": totalPackages,
|
||||
"total_downloads": totalDownloads,
|
||||
"total_bytes": totalBytes,
|
||||
}
|
||||
}
|
||||
|
||||
// SearchPackages finds packages matching a query
|
||||
func (e *Engine) SearchPackages(query string, limit int) []PackageStats {
|
||||
e.statsMu.RLock()
|
||||
defer e.statsMu.RUnlock()
|
||||
|
||||
results := make([]PackageStats, 0)
|
||||
for _, stats := range e.stats {
|
||||
// Simple substring search
|
||||
if contains(stats.Name, query) {
|
||||
results = append(results, *stats)
|
||||
}
|
||||
if len(results) >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by downloads
|
||||
sort.Slice(results, func(i, j int) bool {
|
||||
return results[i].TotalDownloads > results[j].TotalDownloads
|
||||
})
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// contains performs a case-insensitive substring search
|
||||
func contains(s, substr string) bool {
|
||||
sLower := toLower(s)
|
||||
substrLower := toLower(substr)
|
||||
return len(sLower) >= len(substrLower) &&
|
||||
findSubstring(sLower, substrLower)
|
||||
}
|
||||
|
||||
func toLower(s string) string {
|
||||
result := make([]byte, len(s))
|
||||
for i := 0; i < len(s); i++ {
|
||||
c := s[i]
|
||||
if c >= 'A' && c <= 'Z' {
|
||||
result[i] = c + 32
|
||||
} else {
|
||||
result[i] = c
|
||||
}
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func findSubstring(s, substr string) bool {
|
||||
if len(substr) == 0 {
|
||||
return true
|
||||
}
|
||||
if len(s) < len(substr) {
|
||||
return false
|
||||
}
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
match := true
|
||||
for j := 0; j < len(substr); j++ {
|
||||
if s[i+j] != substr[j] {
|
||||
match = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if match {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
+435
@@ -0,0 +1,435 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/adaptor"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/analytics"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/auth"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/cache"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/cdn"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/config"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/health"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/lock"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
|
||||
metafile "github.com/lukaszraczylo/gohoarder/pkg/metadata/file"
|
||||
metasqlite "github.com/lukaszraczylo/gohoarder/pkg/metadata/sqlite"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metrics"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/network"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/prewarming"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/proxy/goproxy"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/proxy/npm"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/proxy/pypi"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/scanner"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/storage"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/storage/filesystem"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/vcs"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/websocket"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// App represents the main application
|
||||
type App struct {
|
||||
config *config.Config
|
||||
app *fiber.App
|
||||
healthChecker *health.Checker
|
||||
cache *cache.Manager
|
||||
storage storage.StorageBackend
|
||||
metadata metadata.Store
|
||||
authManager *auth.Manager
|
||||
networkClient *network.Client
|
||||
scanManager *scanner.Manager
|
||||
rescanWorker *scanner.RescanWorker
|
||||
analyticsEngine *analytics.Engine
|
||||
wsServer *websocket.Server
|
||||
prewarmWorker *prewarming.Worker
|
||||
lockManager *lock.Manager
|
||||
cdnMiddleware *cdn.Middleware
|
||||
}
|
||||
|
||||
// New creates a new application instance
|
||||
func New(cfg *config.Config) (*App, error) {
|
||||
app := &App{
|
||||
config: cfg,
|
||||
}
|
||||
|
||||
// Initialize components
|
||||
if err := app.initializeComponents(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Setup HTTP server and routes
|
||||
if err := app.setupServer(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return app, nil
|
||||
}
|
||||
|
||||
// initializeComponents initializes all application components
|
||||
func (a *App) initializeComponents() error {
|
||||
var err error
|
||||
|
||||
// Initialize storage backend
|
||||
log.Info().Str("backend", a.config.Storage.Backend).Msg("Initializing storage backend")
|
||||
switch a.config.Storage.Backend {
|
||||
case "filesystem":
|
||||
a.storage, err = filesystem.New(a.config.Storage.Path, a.config.Cache.MaxSizeBytes)
|
||||
default:
|
||||
a.storage, err = filesystem.New(a.config.Storage.Path, a.config.Cache.MaxSizeBytes)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize storage: %w", err)
|
||||
}
|
||||
|
||||
// Initialize metadata store
|
||||
log.Info().Str("backend", a.config.Metadata.Backend).Msg("Initializing metadata store")
|
||||
switch a.config.Metadata.Backend {
|
||||
case "sqlite":
|
||||
a.metadata, err = metasqlite.New(metasqlite.Config{
|
||||
Path: a.config.Metadata.Connection,
|
||||
})
|
||||
case "file":
|
||||
a.metadata, err = metafile.New(metafile.Config{
|
||||
Path: a.config.Metadata.Connection,
|
||||
})
|
||||
default:
|
||||
a.metadata, err = metasqlite.New(metasqlite.Config{
|
||||
Path: "gohoarder.db",
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize metadata: %w", err)
|
||||
}
|
||||
|
||||
// Initialize scanner manager first (before cache)
|
||||
log.Info().Msg("Initializing security scanner")
|
||||
a.scanManager, err = scanner.New(a.config.Security, a.metadata)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize scanner: %w", err)
|
||||
}
|
||||
|
||||
// Initialize cache manager with scanner
|
||||
log.Info().Msg("Initializing cache manager")
|
||||
a.cache, err = cache.New(a.storage, a.metadata, a.scanManager, cache.Config{
|
||||
DefaultTTL: a.config.Cache.DefaultTTL,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize cache: %w", err)
|
||||
}
|
||||
|
||||
// Initialize network client
|
||||
log.Info().Msg("Initializing network client")
|
||||
a.networkClient = network.NewClient(network.Config{
|
||||
Timeout: 5 * time.Minute,
|
||||
MaxRetries: 3,
|
||||
RetryDelay: 1 * time.Second,
|
||||
RateLimit: 100,
|
||||
RateBurst: 10,
|
||||
CircuitBreaker: network.CircuitBreakerConfig{
|
||||
Enabled: true,
|
||||
FailureThreshold: 5,
|
||||
SuccessThreshold: 2,
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
UserAgent: "GoHoarder/1.0",
|
||||
})
|
||||
|
||||
// Initialize authentication manager
|
||||
log.Info().Msg("Initializing authentication manager")
|
||||
a.authManager = auth.New()
|
||||
|
||||
// Initialize rescan worker if enabled
|
||||
if a.config.Security.Enabled && a.config.Security.RescanInterval > 0 {
|
||||
log.Info().Dur("interval", a.config.Security.RescanInterval).Msg("Initializing package rescan worker")
|
||||
a.rescanWorker = scanner.NewRescanWorker(a.scanManager, a.metadata, a.storage, a.config.Security.RescanInterval)
|
||||
}
|
||||
|
||||
// Initialize analytics engine
|
||||
log.Info().Msg("Initializing analytics engine")
|
||||
a.analyticsEngine = analytics.NewEngine(analytics.Config{
|
||||
MaxEvents: 10000,
|
||||
FlushInterval: 5 * time.Minute,
|
||||
})
|
||||
|
||||
// Initialize WebSocket server
|
||||
log.Info().Msg("Initializing WebSocket server")
|
||||
a.wsServer = websocket.NewServer(websocket.Config{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(_ *http.Request) bool {
|
||||
return true // Allow all origins in development
|
||||
},
|
||||
})
|
||||
|
||||
// Initialize pre-warming worker
|
||||
log.Info().Msg("Initializing pre-warming worker")
|
||||
a.prewarmWorker = prewarming.NewWorker(prewarming.Config{
|
||||
Enabled: false, // Disabled by default
|
||||
Interval: 1 * time.Hour,
|
||||
MaxConcurrent: 5,
|
||||
CacheManager: a.cache,
|
||||
Analytics: a.analyticsEngine,
|
||||
NetworkClient: a.networkClient,
|
||||
})
|
||||
|
||||
// Initialize CDN middleware
|
||||
log.Info().Msg("Initializing CDN middleware")
|
||||
a.cdnMiddleware = cdn.NewMiddleware(cdn.Config{
|
||||
DefaultCacheControl: cdn.CacheControl{
|
||||
Public: true,
|
||||
MaxAge: 3600,
|
||||
SMaxAge: 7200,
|
||||
},
|
||||
EnableETag: true,
|
||||
EnableVary: true,
|
||||
})
|
||||
|
||||
// Initialize health checker
|
||||
a.healthChecker = health.New()
|
||||
a.healthChecker.AddCheck("storage", func(ctx context.Context) (health.Status, string) {
|
||||
if err := a.storage.Health(ctx); err != nil {
|
||||
return health.StatusUnhealthy, err.Error()
|
||||
}
|
||||
return health.StatusHealthy, ""
|
||||
})
|
||||
a.healthChecker.AddCheck("metadata", func(ctx context.Context) (health.Status, string) {
|
||||
if err := a.metadata.Health(ctx); err != nil {
|
||||
return health.StatusUnhealthy, err.Error()
|
||||
}
|
||||
return health.StatusHealthy, ""
|
||||
})
|
||||
a.healthChecker.AddCheck("cache", func(ctx context.Context) (health.Status, string) {
|
||||
return health.StatusHealthy, "" // Cache is always healthy if initialized
|
||||
})
|
||||
a.healthChecker.AddCheck("scanner", func(ctx context.Context) (health.Status, string) {
|
||||
if a.config.Security.Enabled {
|
||||
if err := a.scanManager.Health(ctx); err != nil {
|
||||
return health.StatusUnhealthy, err.Error()
|
||||
}
|
||||
}
|
||||
return health.StatusHealthy, ""
|
||||
})
|
||||
|
||||
log.Info().Msg("All components initialized successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupServer sets up the Fiber server and routes
|
||||
func (a *App) setupServer() error {
|
||||
// Create Fiber app
|
||||
a.app = fiber.New(fiber.Config{
|
||||
ReadTimeout: a.config.Server.ReadTimeout,
|
||||
WriteTimeout: a.config.Server.WriteTimeout,
|
||||
ServerHeader: "GoHoarder",
|
||||
AppName: "GoHoarder v1.0",
|
||||
})
|
||||
|
||||
// Health and metrics endpoints (adapted from net/http)
|
||||
a.app.Get("/health", adaptor.HTTPHandlerFunc(a.healthChecker.HealthHandler()))
|
||||
a.app.Get("/health/ready", adaptor.HTTPHandlerFunc(a.healthChecker.ReadyHandler()))
|
||||
a.app.Get("/metrics", adaptor.HTTPHandler(metrics.Handler()))
|
||||
|
||||
// WebSocket endpoint (adapted from net/http)
|
||||
a.app.Get("/ws", adaptor.HTTPHandlerFunc(a.wsServer.HandleWebSocket))
|
||||
|
||||
// API endpoints
|
||||
a.app.Get("/api/config", a.handleConfig)
|
||||
a.app.All("/api/packages/*", a.handlePackages) // Handles packages and vulnerabilities
|
||||
a.app.Get("/api/stats", a.handleStats)
|
||||
a.app.Get("/api/stats/timeseries", a.handleTimeSeriesStats)
|
||||
a.app.Get("/api/info", a.handleInfo)
|
||||
|
||||
// Admin endpoints (bypass management)
|
||||
a.app.All("/api/admin/bypasses/:id?", a.requireAdmin, a.handleAdminBypasses)
|
||||
|
||||
// Proxy handlers (adapted from net/http)
|
||||
// Load git credentials if configured
|
||||
var credStore *vcs.CredentialStore
|
||||
if a.config.Handlers.Go.GitCredentialsFile != "" {
|
||||
credStore = vcs.NewCredentialStore()
|
||||
if err := credStore.LoadFromFile(a.config.Handlers.Go.GitCredentialsFile); err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("file", a.config.Handlers.Go.GitCredentialsFile).
|
||||
Msg("Failed to load git credentials, continuing without pattern-based credentials")
|
||||
} else if err := credStore.ValidateConfig(); err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("file", a.config.Handlers.Go.GitCredentialsFile).
|
||||
Msg("Invalid git credentials configuration, continuing without pattern-based credentials")
|
||||
credStore = nil
|
||||
}
|
||||
}
|
||||
|
||||
goProxyHandler := goproxy.New(a.cache, a.networkClient, goproxy.Config{
|
||||
Upstream: "https://proxy.golang.org",
|
||||
SumDBURL: "https://sum.golang.org",
|
||||
CredStore: credStore,
|
||||
})
|
||||
a.app.All("/go/*", adaptor.HTTPHandler(http.StripPrefix("/go", goProxyHandler)))
|
||||
|
||||
npmProxyHandler := npm.New(a.cache, a.networkClient, npm.Config{
|
||||
Upstream: "https://registry.npmjs.org",
|
||||
})
|
||||
a.app.All("/npm/*", adaptor.HTTPHandler(http.StripPrefix("/npm", npmProxyHandler)))
|
||||
|
||||
pypiProxyHandler := pypi.New(a.cache, a.networkClient, pypi.Config{
|
||||
Upstream: "https://pypi.org/simple",
|
||||
})
|
||||
a.app.All("/pypi/*", adaptor.HTTPHandler(http.StripPrefix("/pypi", pypiProxyHandler)))
|
||||
|
||||
// Serve frontend static files
|
||||
frontendDir := "frontend/dist"
|
||||
if _, err := os.Stat(frontendDir); err == nil {
|
||||
log.Info().Str("dir", frontendDir).Msg("Serving frontend static files")
|
||||
a.app.Static("/", frontendDir)
|
||||
} else {
|
||||
log.Warn().Msg("Frontend dist directory not found, frontend won't be served")
|
||||
a.app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.Type("html").SendString(`
|
||||
<html>
|
||||
<head><title>GoHoarder</title></head>
|
||||
<body>
|
||||
<h1>GoHoarder Package Cache Proxy</h1>
|
||||
<p>Frontend not built. Build with: <code>cd frontend && npm run build</code></p>
|
||||
<h2>Available Endpoints:</h2>
|
||||
<ul>
|
||||
<li><a href="/health">Health Check</a></li>
|
||||
<li><a href="/metrics">Metrics</a></li>
|
||||
<li><a href="/api/stats">Statistics API</a></li>
|
||||
</ul>
|
||||
</body>
|
||||
</html>
|
||||
`)
|
||||
})
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("addr", fmt.Sprintf("%s:%d", a.config.Server.Host, a.config.Server.Port)).
|
||||
Msg("Fiber server configured")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run starts the application
|
||||
func (a *App) Run() error {
|
||||
ctx := context.Background()
|
||||
|
||||
// Start WebSocket server
|
||||
a.wsServer.Start(ctx)
|
||||
|
||||
// Start pre-warming worker
|
||||
a.prewarmWorker.Start(ctx)
|
||||
|
||||
// Start rescan worker if enabled
|
||||
if a.rescanWorker != nil {
|
||||
go a.rescanWorker.Start(ctx)
|
||||
}
|
||||
|
||||
// Start download data aggregation worker (runs every hour)
|
||||
go a.startAggregationWorker(ctx)
|
||||
|
||||
// Start Fiber server in goroutine
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
addr := fmt.Sprintf("%s:%d", a.config.Server.Host, a.config.Server.Port)
|
||||
log.Info().
|
||||
Str("addr", addr).
|
||||
Msg("Starting Fiber server")
|
||||
if err := a.app.Listen(addr); err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for interrupt signal
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
return fmt.Errorf("server error: %w", err)
|
||||
case sig := <-sigChan:
|
||||
log.Info().
|
||||
Str("signal", sig.String()).
|
||||
Msg("Shutdown signal received")
|
||||
}
|
||||
|
||||
// Graceful shutdown
|
||||
return a.Shutdown()
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the application
|
||||
func (a *App) Shutdown() error {
|
||||
log.Info().Msg("Starting graceful shutdown")
|
||||
|
||||
// Stop Fiber server
|
||||
if err := a.app.Shutdown(); err != nil {
|
||||
log.Error().Err(err).Msg("Error shutting down Fiber server")
|
||||
}
|
||||
|
||||
// Stop pre-warming worker
|
||||
a.prewarmWorker.Stop()
|
||||
|
||||
// Stop rescan worker if running
|
||||
if a.rescanWorker != nil {
|
||||
a.rescanWorker.Stop()
|
||||
}
|
||||
|
||||
// Close analytics engine
|
||||
a.analyticsEngine.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
// Close storage
|
||||
if err := a.storage.Close(); err != nil {
|
||||
log.Error().Err(err).Msg("Error closing storage")
|
||||
}
|
||||
|
||||
// Close metadata store
|
||||
if err := a.metadata.Close(); err != nil {
|
||||
log.Error().Err(err).Msg("Error closing metadata store")
|
||||
}
|
||||
|
||||
// Close lock manager if initialized
|
||||
if a.lockManager != nil {
|
||||
if err := a.lockManager.Close(); err != nil {
|
||||
log.Error().Err(err).Msg("Error closing lock manager")
|
||||
}
|
||||
}
|
||||
|
||||
log.Info().Msg("Shutdown complete")
|
||||
return nil
|
||||
}
|
||||
|
||||
// startAggregationWorker runs download data aggregation periodically
|
||||
func (a *App) startAggregationWorker(ctx context.Context) {
|
||||
log.Info().Msg("Starting download data aggregation worker (runs every hour)")
|
||||
|
||||
// Run immediately on startup
|
||||
if err := a.metadata.AggregateDownloadData(ctx); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to run initial download data aggregation")
|
||||
}
|
||||
|
||||
// Then run every hour
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info().Msg("Aggregation worker stopped")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := a.metadata.AggregateDownloadData(ctx); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to aggregate download data")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,512 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/lukaszraczylo/gohoarder/internal/version"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/websocket"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// handlePackages handles /api/packages endpoint
|
||||
func (a *App) handlePackages(c *fiber.Ctx) error {
|
||||
c.Set("Content-Type", "application/json")
|
||||
c.Set("Access-Control-Allow-Origin", "*")
|
||||
c.Set("Access-Control-Allow-Methods", "GET, DELETE, OPTIONS")
|
||||
c.Set("Access-Control-Allow-Headers", "Content-Type")
|
||||
|
||||
if c.Method() == "OPTIONS" {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
}
|
||||
|
||||
// Check if this is a vulnerability endpoint request
|
||||
if strings.HasSuffix(c.Path(), "/vulnerabilities") {
|
||||
return a.handleVulnerabilities(c)
|
||||
}
|
||||
|
||||
switch c.Method() {
|
||||
case "GET":
|
||||
return a.handleListPackages(c)
|
||||
case "DELETE":
|
||||
return a.handleDeletePackage(c)
|
||||
default:
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "method not allowed"})
|
||||
}
|
||||
}
|
||||
|
||||
// handleListPackages returns list of cached packages
|
||||
func (a *App) handleListPackages(c *fiber.Ctx) error {
|
||||
ctx := c.Context()
|
||||
|
||||
// Get packages from metadata store
|
||||
allPackages, err := a.metadata.ListPackages(ctx, &metadata.ListOptions{
|
||||
Limit: 1000, // Get more to account for duplicates
|
||||
Offset: 0,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to list packages")
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "failed to list packages"})
|
||||
}
|
||||
|
||||
log.Debug().Int("total_packages_from_db", len(allPackages)).Msg("Retrieved packages from database")
|
||||
|
||||
// Filter, clean, and deduplicate packages
|
||||
// Map stores both cleaned package and original name for scan lookups
|
||||
type packageEntry struct {
|
||||
pkg *metadata.Package
|
||||
originalName string
|
||||
}
|
||||
seen := make(map[string]*packageEntry)
|
||||
skippedCount := 0
|
||||
for _, pkg := range allPackages {
|
||||
// Skip metadata entries (npm metadata pages, pypi pages, etc.)
|
||||
if pkg.Version == "list" || pkg.Version == "latest" || pkg.Version == "metadata" || pkg.Version == "page" {
|
||||
skippedCount++
|
||||
log.Debug().
|
||||
Str("name", pkg.Name).
|
||||
Str("version", pkg.Version).
|
||||
Str("registry", pkg.Registry).
|
||||
Msg("Skipping metadata entry")
|
||||
continue
|
||||
}
|
||||
|
||||
// Clean the package name (remove /@v/version.ext suffix)
|
||||
originalName := pkg.Name
|
||||
cleanName := pkg.Name
|
||||
if idx := strings.Index(cleanName, "/@v/"); idx != -1 {
|
||||
cleanName = cleanName[:idx]
|
||||
}
|
||||
|
||||
// Create deduplication key
|
||||
key := cleanName + "@" + pkg.Version
|
||||
|
||||
// Keep the entry with the largest size (typically .zip files)
|
||||
if existing, ok := seen[key]; !ok || pkg.Size > existing.pkg.Size {
|
||||
// Create a copy with cleaned name
|
||||
cleanPkg := *pkg
|
||||
cleanPkg.Name = cleanName
|
||||
seen[key] = &packageEntry{
|
||||
pkg: &cleanPkg,
|
||||
originalName: originalName,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Int("skipped_metadata", skippedCount).
|
||||
Int("unique_packages", len(seen)).
|
||||
Msg("Filtered and deduplicated packages")
|
||||
|
||||
// Convert map to slice, keeping track of original names
|
||||
type packageWithOriginalName struct {
|
||||
pkg *metadata.Package
|
||||
originalName string
|
||||
}
|
||||
packagesWithNames := make([]packageWithOriginalName, 0, len(seen))
|
||||
for _, entry := range seen {
|
||||
packagesWithNames = append(packagesWithNames, packageWithOriginalName{
|
||||
pkg: entry.pkg,
|
||||
originalName: entry.originalName,
|
||||
})
|
||||
}
|
||||
|
||||
// Enhance packages with vulnerability information if security scanning is enabled
|
||||
var response map[string]interface{}
|
||||
if a.config.Security.Enabled {
|
||||
enhancedPackages := make([]map[string]interface{}, 0, len(packagesWithNames))
|
||||
for _, entry := range packagesWithNames {
|
||||
pkg := entry.pkg
|
||||
pkgMap := map[string]interface{}{
|
||||
"id": pkg.ID,
|
||||
"registry": pkg.Registry,
|
||||
"name": pkg.Name,
|
||||
"version": pkg.Version,
|
||||
"size": pkg.Size,
|
||||
"checksum_sha256": pkg.ChecksumSHA256,
|
||||
"cached_at": pkg.CachedAt,
|
||||
"last_accessed": pkg.LastAccessed,
|
||||
"download_count": pkg.DownloadCount,
|
||||
}
|
||||
|
||||
// Add vulnerability info if scanned
|
||||
if pkg.SecurityScanned {
|
||||
// Use original name for scan result lookup (handles Go packages with /@v/ suffix)
|
||||
scanResult, err := a.metadata.GetScanResult(ctx, pkg.Registry, entry.originalName, pkg.Version)
|
||||
if err == nil && scanResult != nil {
|
||||
// Count vulnerabilities by severity
|
||||
severityCounts := make(map[string]int)
|
||||
for _, vuln := range scanResult.Vulnerabilities {
|
||||
severityCounts[strings.ToUpper(vuln.Severity)]++
|
||||
}
|
||||
|
||||
pkgMap["vulnerabilities"] = map[string]interface{}{
|
||||
"scanned": true,
|
||||
"status": scanResult.Status,
|
||||
"scannedAt": scanResult.ScannedAt.Format(time.RFC3339),
|
||||
"counts": map[string]int{
|
||||
"critical": severityCounts["CRITICAL"],
|
||||
"high": severityCounts["HIGH"],
|
||||
"moderate": severityCounts["MODERATE"],
|
||||
"low": severityCounts["LOW"],
|
||||
},
|
||||
"total": scanResult.VulnerabilityCount,
|
||||
}
|
||||
} else {
|
||||
pkgMap["vulnerabilities"] = map[string]interface{}{
|
||||
"scanned": false,
|
||||
"status": "pending",
|
||||
}
|
||||
}
|
||||
} else {
|
||||
pkgMap["vulnerabilities"] = map[string]interface{}{
|
||||
"scanned": false,
|
||||
"status": "not_scanned",
|
||||
}
|
||||
}
|
||||
|
||||
enhancedPackages = append(enhancedPackages, pkgMap)
|
||||
}
|
||||
|
||||
response = map[string]interface{}{
|
||||
"packages": enhancedPackages,
|
||||
"total": len(enhancedPackages),
|
||||
}
|
||||
} else {
|
||||
// Non-enhanced mode - just return the packages
|
||||
packages := make([]*metadata.Package, 0, len(packagesWithNames))
|
||||
for _, entry := range packagesWithNames {
|
||||
packages = append(packages, entry.pkg)
|
||||
}
|
||||
response = map[string]interface{}{
|
||||
"packages": packages,
|
||||
"total": len(packages),
|
||||
}
|
||||
}
|
||||
|
||||
// Success response
|
||||
return c.Status(fiber.StatusOK).JSON(response)
|
||||
}
|
||||
|
||||
// handleDeletePackage deletes a cached package
|
||||
func (a *App) handleDeletePackage(c *fiber.Ctx) error {
|
||||
ctx := c.Context()
|
||||
|
||||
// Parse path: /api/packages/{registry}/{name}/{version}
|
||||
// For Go packages, name can contain slashes (e.g., github.com/user/repo)
|
||||
// Version is always the last segment
|
||||
path := strings.TrimPrefix(c.Path(), "/api/packages/")
|
||||
parts := strings.Split(path, "/")
|
||||
if len(parts) < 3 {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
"error": "invalid path format, expected /api/packages/{registry}/{name}/{version}",
|
||||
})
|
||||
}
|
||||
|
||||
registry := parts[0]
|
||||
version := parts[len(parts)-1]
|
||||
name := strings.Join(parts[1:len(parts)-1], "/")
|
||||
|
||||
// For Go packages, we need to find and delete all cache entries (.info, .mod, .zip)
|
||||
// For other registries, we can delete directly
|
||||
var deletedCount int
|
||||
var lastErr error
|
||||
|
||||
if registry == "go" {
|
||||
// List all packages matching the base name and version
|
||||
allPackages, err := a.metadata.ListPackages(ctx, &metadata.ListOptions{
|
||||
Limit: 1000,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to list packages for deletion")
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "failed to list packages"})
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("registry", registry).
|
||||
Str("name", name).
|
||||
Str("version", version).
|
||||
Int("total_packages", len(allPackages)).
|
||||
Msg("Searching for packages to delete")
|
||||
|
||||
// Find and delete all entries for this package
|
||||
for _, pkg := range allPackages {
|
||||
if pkg.Registry != registry || pkg.Version != version {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this package name matches (either exact or with /@v/ suffix)
|
||||
cleanName := pkg.Name
|
||||
if idx := strings.Index(cleanName, "/@v/"); idx != -1 {
|
||||
cleanName = cleanName[:idx]
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("db_name", pkg.Name).
|
||||
Str("clean_name", cleanName).
|
||||
Str("search_name", name).
|
||||
Bool("matches", cleanName == name).
|
||||
Msg("Checking package")
|
||||
|
||||
if cleanName == name {
|
||||
if err := a.cache.Delete(ctx, pkg.Registry, pkg.Name, pkg.Version); err != nil {
|
||||
log.Warn().
|
||||
Err(err).
|
||||
Str("registry", pkg.Registry).
|
||||
Str("name", pkg.Name).
|
||||
Str("version", pkg.Version).
|
||||
Msg("Failed to delete package variant")
|
||||
lastErr = err
|
||||
} else {
|
||||
deletedCount++
|
||||
log.Info().
|
||||
Str("registry", pkg.Registry).
|
||||
Str("name", pkg.Name).
|
||||
Str("version", pkg.Version).
|
||||
Msg("Deleted package variant")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Int("deleted_count", deletedCount).
|
||||
Msg("Delete operation completed")
|
||||
|
||||
if deletedCount == 0 {
|
||||
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "package not found"})
|
||||
}
|
||||
|
||||
if lastErr != nil && deletedCount == 0 {
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "failed to delete package"})
|
||||
}
|
||||
} else {
|
||||
// For NPM and PyPI, delete directly
|
||||
if err := a.cache.Delete(ctx, registry, name, version); err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("registry", registry).
|
||||
Str("name", name).
|
||||
Str("version", version).
|
||||
Msg("Failed to delete package")
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "failed to delete package"})
|
||||
}
|
||||
deletedCount = 1
|
||||
}
|
||||
|
||||
// Broadcast event via WebSocket
|
||||
a.wsServer.Broadcast(websocket.EventPackageDeleted, map[string]interface{}{
|
||||
"registry": registry,
|
||||
"name": name,
|
||||
"version": version,
|
||||
})
|
||||
|
||||
// Success response
|
||||
response := map[string]interface{}{
|
||||
"deleted": true,
|
||||
"package": map[string]string{
|
||||
"registry": registry,
|
||||
"name": name,
|
||||
"version": version,
|
||||
},
|
||||
}
|
||||
|
||||
// For Go packages, include count of deleted variants
|
||||
if registry == "go" {
|
||||
response["deleted_count"] = deletedCount
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(response)
|
||||
}
|
||||
|
||||
// handleStats handles /api/stats endpoint
|
||||
func (a *App) handleStats(c *fiber.Ctx) error {
|
||||
c.Set("Content-Type", "application/json")
|
||||
c.Set("Access-Control-Allow-Origin", "*")
|
||||
c.Set("Access-Control-Allow-Methods", "GET, OPTIONS")
|
||||
c.Set("Access-Control-Allow-Headers", "Content-Type")
|
||||
|
||||
if c.Method() == "OPTIONS" {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
}
|
||||
|
||||
if c.Method() != "GET" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "method not allowed"})
|
||||
}
|
||||
|
||||
ctx := c.Context()
|
||||
|
||||
// Get cache statistics for all registries from database
|
||||
cacheStats, err := a.cache.GetStats(ctx, "")
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get cache stats")
|
||||
cacheStats = &metadata.Stats{}
|
||||
}
|
||||
|
||||
// Get all packages to calculate per-registry breakdown
|
||||
packages, err := a.metadata.ListPackages(ctx, nil)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to list packages")
|
||||
packages = []*metadata.Package{}
|
||||
}
|
||||
|
||||
// Calculate per-registry breakdown (exclude metadata entries like "list", "latest")
|
||||
registryStats := make(map[string]map[string]interface{})
|
||||
|
||||
for _, pkg := range packages {
|
||||
// Skip metadata entries (npm metadata pages, pypi pages, etc.)
|
||||
if pkg.Version == "list" || pkg.Version == "latest" || pkg.Version == "metadata" || pkg.Version == "page" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Track per-registry stats
|
||||
if _, ok := registryStats[pkg.Registry]; !ok {
|
||||
registryStats[pkg.Registry] = map[string]interface{}{
|
||||
"count": 0,
|
||||
"size": int64(0),
|
||||
"downloads": int64(0),
|
||||
}
|
||||
}
|
||||
registryStats[pkg.Registry]["count"] = registryStats[pkg.Registry]["count"].(int) + 1
|
||||
registryStats[pkg.Registry]["size"] = registryStats[pkg.Registry]["size"].(int64) + pkg.Size
|
||||
registryStats[pkg.Registry]["downloads"] = registryStats[pkg.Registry]["downloads"].(int64) + int64(pkg.DownloadCount)
|
||||
}
|
||||
|
||||
// Combine statistics using database stats for accuracy
|
||||
stats := map[string]interface{}{
|
||||
"total_packages": cacheStats.TotalPackages,
|
||||
"total_downloads": cacheStats.TotalDownloads,
|
||||
"total_size": cacheStats.TotalSize,
|
||||
"cache_hits": cacheStats.TotalDownloads,
|
||||
"cache_misses": 0, // TODO: Track cache misses
|
||||
"cache_evictions": 0, // TODO: Track evictions
|
||||
"cache_size": cacheStats.TotalSize,
|
||||
"scanned_packages": cacheStats.ScannedPackages,
|
||||
"vulnerable_packages": cacheStats.VulnerablePackages,
|
||||
}
|
||||
|
||||
// Convert registry stats to interface map
|
||||
registries := make(map[string]interface{})
|
||||
for registry, regStats := range registryStats {
|
||||
registries[registry] = regStats
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(fiber.Map{
|
||||
"stats": stats,
|
||||
"registries": registries,
|
||||
})
|
||||
}
|
||||
|
||||
// handleTimeSeriesStats handles /api/stats/timeseries endpoint
|
||||
// Returns time-series download statistics for charts
|
||||
func (a *App) handleTimeSeriesStats(c *fiber.Ctx) error {
|
||||
c.Set("Content-Type", "application/json")
|
||||
c.Set("Access-Control-Allow-Origin", "*")
|
||||
c.Set("Access-Control-Allow-Methods", "GET, OPTIONS")
|
||||
c.Set("Access-Control-Allow-Headers", "Content-Type")
|
||||
|
||||
if c.Method() == "OPTIONS" {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
}
|
||||
|
||||
if c.Method() != "GET" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "method not allowed"})
|
||||
}
|
||||
|
||||
ctx := c.Context()
|
||||
|
||||
// Get query parameters
|
||||
period := c.Query("period", "1day") // Default to 1 day
|
||||
registry := c.Query("registry") // Optional registry filter
|
||||
|
||||
// Validate period
|
||||
validPeriods := map[string]bool{"1h": true, "1day": true, "7day": true, "30day": true}
|
||||
if !validPeriods[period] {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
"error": "invalid period, must be one of: 1h, 1day, 7day, 30day",
|
||||
})
|
||||
}
|
||||
|
||||
// Get time-series stats
|
||||
stats, err := a.metadata.GetTimeSeriesStats(ctx, period, registry)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("period", period).Str("registry", registry).Msg("Failed to get time-series stats")
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
"error": "failed to get time-series statistics",
|
||||
})
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(stats)
|
||||
}
|
||||
|
||||
// handleConfig handles /api/config endpoint
|
||||
// Returns runtime configuration for the frontend
|
||||
func (a *App) handleConfig(c *fiber.Ctx) error {
|
||||
c.Set("Content-Type", "application/json")
|
||||
c.Set("Access-Control-Allow-Origin", "*")
|
||||
c.Set("Access-Control-Allow-Methods", "GET, OPTIONS")
|
||||
c.Set("Access-Control-Allow-Headers", "Content-Type")
|
||||
|
||||
if c.Method() == "OPTIONS" {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
}
|
||||
|
||||
if c.Method() != "GET" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "method not allowed"})
|
||||
}
|
||||
|
||||
// Build server URL from request
|
||||
scheme := "http"
|
||||
if c.Protocol() == "https" {
|
||||
scheme = "https"
|
||||
}
|
||||
serverURL := scheme + "://" + c.Hostname()
|
||||
|
||||
config := map[string]interface{}{
|
||||
"server_url": serverURL,
|
||||
"version": version.Version,
|
||||
"features": map[string]bool{
|
||||
"security_scanning": a.config.Security.Enabled,
|
||||
"websockets": true,
|
||||
},
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(config)
|
||||
}
|
||||
|
||||
// handleInfo handles /api/info endpoint
|
||||
func (a *App) handleInfo(c *fiber.Ctx) error {
|
||||
c.Set("Content-Type", "application/json")
|
||||
c.Set("Access-Control-Allow-Origin", "*")
|
||||
c.Set("Access-Control-Allow-Methods", "GET, OPTIONS")
|
||||
c.Set("Access-Control-Allow-Headers", "Content-Type")
|
||||
|
||||
if c.Method() == "OPTIONS" {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
}
|
||||
|
||||
if c.Method() != "GET" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "method not allowed"})
|
||||
}
|
||||
|
||||
info := map[string]interface{}{
|
||||
"name": "GoHoarder",
|
||||
"version": version.Version,
|
||||
"config": map[string]interface{}{
|
||||
"storage_backend": a.config.Storage.Backend,
|
||||
"metadata_backend": a.config.Metadata.Backend,
|
||||
"cache_ttl": a.config.Cache.DefaultTTL.String(),
|
||||
"max_cache_size": a.config.Cache.MaxSizeBytes,
|
||||
},
|
||||
"features": map[string]bool{
|
||||
"distributed_locking": a.lockManager != nil,
|
||||
"security_scanning": a.config.Security.Enabled,
|
||||
"pre_warming": a.prewarmWorker != nil,
|
||||
"websockets": true,
|
||||
"analytics": true,
|
||||
},
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(info)
|
||||
}
|
||||
@@ -0,0 +1,413 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/internal/version"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/errors"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/websocket"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// handlePackages handles /api/packages endpoint
|
||||
func (a *App) handlePackages(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, DELETE, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
|
||||
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if this is a vulnerability endpoint request
|
||||
if strings.HasSuffix(r.URL.Path, "/vulnerabilities") {
|
||||
a.handleVulnerabilities(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case "GET":
|
||||
a.handleListPackages(w, r)
|
||||
case "DELETE":
|
||||
a.handleDeletePackage(w, r)
|
||||
default:
|
||||
errors.WriteErrorSimple(w, errors.BadRequest("method not allowed"))
|
||||
}
|
||||
}
|
||||
|
||||
// handleListPackages returns list of cached packages
|
||||
func (a *App) handleListPackages(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Get packages from metadata store
|
||||
allPackages, err := a.metadata.ListPackages(ctx, &metadata.ListOptions{
|
||||
Limit: 1000, // Get more to account for duplicates
|
||||
Offset: 0,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to list packages")
|
||||
errors.WriteErrorSimple(w, errors.InternalServer("failed to list packages"))
|
||||
return
|
||||
}
|
||||
|
||||
// Filter, clean, and deduplicate packages
|
||||
seen := make(map[string]*metadata.Package)
|
||||
for _, pkg := range allPackages {
|
||||
// Skip metadata entries
|
||||
if pkg.Version == "list" || pkg.Version == "latest" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Clean the package name (remove /@v/version.ext suffix)
|
||||
cleanName := pkg.Name
|
||||
if idx := strings.Index(cleanName, "/@v/"); idx != -1 {
|
||||
cleanName = cleanName[:idx]
|
||||
}
|
||||
|
||||
// Create deduplication key
|
||||
key := cleanName + "@" + pkg.Version
|
||||
|
||||
// Keep the entry with the largest size (typically .zip files)
|
||||
if existing, ok := seen[key]; !ok || pkg.Size > existing.Size {
|
||||
// Create a copy with cleaned name
|
||||
cleanPkg := *pkg
|
||||
cleanPkg.Name = cleanName
|
||||
seen[key] = &cleanPkg
|
||||
}
|
||||
}
|
||||
|
||||
// Convert map to slice
|
||||
packages := make([]*metadata.Package, 0, len(seen))
|
||||
for _, pkg := range seen {
|
||||
packages = append(packages, pkg)
|
||||
}
|
||||
|
||||
// Enhance packages with vulnerability information if security scanning is enabled
|
||||
var response map[string]interface{}
|
||||
if a.config.Security.Enabled {
|
||||
enhancedPackages := make([]map[string]interface{}, 0, len(packages))
|
||||
for _, pkg := range packages {
|
||||
pkgMap := map[string]interface{}{
|
||||
"id": pkg.ID,
|
||||
"registry": pkg.Registry,
|
||||
"name": pkg.Name,
|
||||
"version": pkg.Version,
|
||||
"size": pkg.Size,
|
||||
"checksum_sha256": pkg.ChecksumSHA256,
|
||||
"cached_at": pkg.CachedAt,
|
||||
"last_accessed": pkg.LastAccessed,
|
||||
"download_count": pkg.DownloadCount,
|
||||
}
|
||||
|
||||
// Add vulnerability info if scanned
|
||||
if pkg.SecurityScanned {
|
||||
scanResult, err := a.metadata.GetScanResult(ctx, pkg.Registry, pkg.Name, pkg.Version)
|
||||
if err == nil && scanResult != nil {
|
||||
// Count vulnerabilities by severity
|
||||
severityCounts := make(map[string]int)
|
||||
for _, vuln := range scanResult.Vulnerabilities {
|
||||
severityCounts[strings.ToUpper(vuln.Severity)]++
|
||||
}
|
||||
|
||||
pkgMap["vulnerabilities"] = map[string]interface{}{
|
||||
"scanned": true,
|
||||
"status": scanResult.Status,
|
||||
"counts": map[string]int{
|
||||
"critical": severityCounts["CRITICAL"],
|
||||
"high": severityCounts["HIGH"],
|
||||
"medium": severityCounts["MEDIUM"],
|
||||
"low": severityCounts["LOW"],
|
||||
},
|
||||
"total": scanResult.VulnerabilityCount,
|
||||
}
|
||||
} else {
|
||||
pkgMap["vulnerabilities"] = map[string]interface{}{
|
||||
"scanned": false,
|
||||
"status": "pending",
|
||||
}
|
||||
}
|
||||
} else {
|
||||
pkgMap["vulnerabilities"] = map[string]interface{}{
|
||||
"scanned": false,
|
||||
"status": "not_scanned",
|
||||
}
|
||||
}
|
||||
|
||||
enhancedPackages = append(enhancedPackages, pkgMap)
|
||||
}
|
||||
|
||||
response = map[string]interface{}{
|
||||
"packages": enhancedPackages,
|
||||
"total": len(enhancedPackages),
|
||||
}
|
||||
} else {
|
||||
response = map[string]interface{}{
|
||||
"packages": packages,
|
||||
"total": len(packages),
|
||||
}
|
||||
}
|
||||
|
||||
// Success response
|
||||
errors.WriteJSONSimple(w, http.StatusOK, response)
|
||||
}
|
||||
|
||||
// handleDeletePackage deletes a cached package
|
||||
func (a *App) handleDeletePackage(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Parse path: /api/packages/{registry}/{name}/{version}
|
||||
// For Go packages, name can contain slashes (e.g., github.com/user/repo)
|
||||
// Version is always the last segment
|
||||
path := strings.TrimPrefix(r.URL.Path, "/api/packages/")
|
||||
parts := strings.Split(path, "/")
|
||||
if len(parts) < 3 {
|
||||
errors.WriteErrorSimple(w, errors.BadRequest("invalid path format, expected /api/packages/{registry}/{name}/{version}"))
|
||||
return
|
||||
}
|
||||
|
||||
registry := parts[0]
|
||||
version := parts[len(parts)-1]
|
||||
name := strings.Join(parts[1:len(parts)-1], "/")
|
||||
|
||||
// For Go packages, we need to find and delete all cache entries (.info, .mod, .zip)
|
||||
// For other registries, we can delete directly
|
||||
var deletedCount int
|
||||
var lastErr error
|
||||
|
||||
if registry == "go" {
|
||||
// List all packages matching the base name and version
|
||||
allPackages, err := a.metadata.ListPackages(ctx, &metadata.ListOptions{
|
||||
Limit: 1000,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to list packages for deletion")
|
||||
errors.WriteErrorSimple(w, errors.InternalServer("failed to list packages"))
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("registry", registry).
|
||||
Str("name", name).
|
||||
Str("version", version).
|
||||
Int("total_packages", len(allPackages)).
|
||||
Msg("Searching for packages to delete")
|
||||
|
||||
// Find and delete all entries for this package
|
||||
for _, pkg := range allPackages {
|
||||
if pkg.Registry != registry || pkg.Version != version {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this package name matches (either exact or with /@v/ suffix)
|
||||
cleanName := pkg.Name
|
||||
if idx := strings.Index(cleanName, "/@v/"); idx != -1 {
|
||||
cleanName = cleanName[:idx]
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("db_name", pkg.Name).
|
||||
Str("clean_name", cleanName).
|
||||
Str("search_name", name).
|
||||
Bool("matches", cleanName == name).
|
||||
Msg("Checking package")
|
||||
|
||||
if cleanName == name {
|
||||
if err := a.cache.Delete(ctx, pkg.Registry, pkg.Name, pkg.Version); err != nil {
|
||||
log.Warn().
|
||||
Err(err).
|
||||
Str("registry", pkg.Registry).
|
||||
Str("name", pkg.Name).
|
||||
Str("version", pkg.Version).
|
||||
Msg("Failed to delete package variant")
|
||||
lastErr = err
|
||||
} else {
|
||||
deletedCount++
|
||||
log.Info().
|
||||
Str("registry", pkg.Registry).
|
||||
Str("name", pkg.Name).
|
||||
Str("version", pkg.Version).
|
||||
Msg("Deleted package variant")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Int("deleted_count", deletedCount).
|
||||
Msg("Delete operation completed")
|
||||
|
||||
if deletedCount == 0 {
|
||||
errors.WriteErrorSimple(w, errors.NotFound("package not found"))
|
||||
return
|
||||
}
|
||||
|
||||
if lastErr != nil && deletedCount == 0 {
|
||||
errors.WriteErrorSimple(w, errors.InternalServer("failed to delete package"))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// For NPM and PyPI, delete directly
|
||||
if err := a.cache.Delete(ctx, registry, name, version); err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("registry", registry).
|
||||
Str("name", name).
|
||||
Str("version", version).
|
||||
Msg("Failed to delete package")
|
||||
errors.WriteErrorSimple(w, errors.InternalServer("failed to delete package"))
|
||||
return
|
||||
}
|
||||
deletedCount = 1
|
||||
}
|
||||
|
||||
// Broadcast event via WebSocket
|
||||
a.wsServer.Broadcast(websocket.EventPackageDeleted, map[string]interface{}{
|
||||
"registry": registry,
|
||||
"name": name,
|
||||
"version": version,
|
||||
})
|
||||
|
||||
// Success response
|
||||
response := map[string]interface{}{
|
||||
"deleted": true,
|
||||
"package": map[string]string{
|
||||
"registry": registry,
|
||||
"name": name,
|
||||
"version": version,
|
||||
},
|
||||
}
|
||||
|
||||
// For Go packages, include count of deleted variants
|
||||
if registry == "go" {
|
||||
response["deleted_count"] = deletedCount
|
||||
}
|
||||
|
||||
errors.WriteJSONSimple(w, http.StatusOK, response)
|
||||
}
|
||||
|
||||
// handleStats handles /api/stats endpoint
|
||||
func (a *App) handleStats(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
|
||||
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method != "GET" {
|
||||
errors.WriteErrorSimple(w, errors.BadRequest("method not allowed"))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
// Get cache statistics for all registries
|
||||
cacheStats, err := a.cache.GetStats(ctx, "")
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get cache stats")
|
||||
cacheStats = &metadata.Stats{}
|
||||
}
|
||||
|
||||
// Get all packages to calculate total size and downloads
|
||||
packages, err := a.metadata.ListPackages(ctx, nil)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to list packages")
|
||||
packages = []*metadata.Package{}
|
||||
}
|
||||
|
||||
// Calculate totals and registry breakdown from actual packages (exclude metadata entries like "list", "latest")
|
||||
var totalSize int64
|
||||
var totalDownloads int64
|
||||
var actualPackageCount int
|
||||
registryStats := make(map[string]map[string]interface{})
|
||||
|
||||
for _, pkg := range packages {
|
||||
// Skip metadata entries
|
||||
if pkg.Version == "list" || pkg.Version == "latest" {
|
||||
continue
|
||||
}
|
||||
totalSize += pkg.Size
|
||||
totalDownloads += int64(pkg.DownloadCount)
|
||||
actualPackageCount++
|
||||
|
||||
// Track per-registry stats
|
||||
if _, ok := registryStats[pkg.Registry]; !ok {
|
||||
registryStats[pkg.Registry] = map[string]interface{}{
|
||||
"count": 0,
|
||||
"size": int64(0),
|
||||
"downloads": int64(0),
|
||||
}
|
||||
}
|
||||
registryStats[pkg.Registry]["count"] = registryStats[pkg.Registry]["count"].(int) + 1
|
||||
registryStats[pkg.Registry]["size"] = registryStats[pkg.Registry]["size"].(int64) + pkg.Size
|
||||
registryStats[pkg.Registry]["downloads"] = registryStats[pkg.Registry]["downloads"].(int64) + int64(pkg.DownloadCount)
|
||||
}
|
||||
|
||||
// Combine statistics
|
||||
stats := map[string]interface{}{
|
||||
"total_packages": actualPackageCount,
|
||||
"total_downloads": totalDownloads,
|
||||
"total_size": totalSize,
|
||||
"cache_hits": cacheStats.TotalDownloads,
|
||||
"cache_misses": 0, // TODO: Track cache misses
|
||||
"cache_evictions": 0, // TODO: Track evictions
|
||||
"cache_size": cacheStats.TotalSize,
|
||||
"scanned_packages": cacheStats.ScannedPackages,
|
||||
"vulnerable_packages": cacheStats.VulnerablePackages,
|
||||
}
|
||||
|
||||
// Convert registry stats to interface map
|
||||
registries := make(map[string]interface{})
|
||||
for registry, regStats := range registryStats {
|
||||
registries[registry] = regStats
|
||||
}
|
||||
|
||||
errors.WriteJSONSimple(w, http.StatusOK, map[string]interface{}{
|
||||
"stats": stats,
|
||||
"registries": registries,
|
||||
})
|
||||
}
|
||||
|
||||
// handleInfo handles /api/info endpoint
|
||||
func (a *App) handleInfo(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
|
||||
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method != "GET" {
|
||||
errors.WriteErrorSimple(w, errors.BadRequest("method not allowed"))
|
||||
return
|
||||
}
|
||||
|
||||
info := map[string]interface{}{
|
||||
"name": "GoHoarder",
|
||||
"version": version.Version,
|
||||
"config": map[string]interface{}{
|
||||
"storage_backend": a.config.Storage.Backend,
|
||||
"metadata_backend": a.config.Metadata.Backend,
|
||||
"cache_ttl": a.config.Cache.DefaultTTL.String(),
|
||||
"max_cache_size": a.config.Cache.MaxSizeBytes,
|
||||
},
|
||||
"features": map[string]bool{
|
||||
"distributed_locking": a.lockManager != nil,
|
||||
"security_scanning": a.config.Security.Enabled,
|
||||
"pre_warming": a.prewarmWorker != nil,
|
||||
"websockets": true,
|
||||
"analytics": true,
|
||||
},
|
||||
}
|
||||
|
||||
errors.WriteJSONSimple(w, http.StatusOK, info)
|
||||
}
|
||||
@@ -0,0 +1,415 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/internal/version"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/errors"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/websocket"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// handlePackages handles /api/packages endpoint
|
||||
func (a *App) handlePackages(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, DELETE, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
|
||||
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if this is a vulnerability endpoint request
|
||||
if strings.HasSuffix(r.URL.Path, "/vulnerabilities") {
|
||||
a.handleVulnerabilities(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case "GET":
|
||||
a.handleListPackages(w, r)
|
||||
case "DELETE":
|
||||
a.handleDeletePackage(w, r)
|
||||
default:
|
||||
errors.WriteErrorSimple(w, errors.BadRequest("method not allowed"))
|
||||
}
|
||||
}
|
||||
|
||||
// handleListPackages returns list of cached packages
|
||||
func (a *App) handleListPackages(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Get packages from metadata store
|
||||
allPackages, err := a.metadata.ListPackages(ctx, &metadata.ListOptions{
|
||||
Limit: 1000, // Get more to account for duplicates
|
||||
Offset: 0,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to list packages")
|
||||
errors.WriteErrorSimple(w, errors.InternalServer("failed to list packages"))
|
||||
return
|
||||
}
|
||||
|
||||
// Filter, clean, and deduplicate packages
|
||||
seen := make(map[string]*metadata.Package)
|
||||
for _, pkg := range allPackages {
|
||||
// Skip metadata entries
|
||||
if pkg.Version == "list" || pkg.Version == "latest" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Clean the package name (remove /@v/version.ext suffix)
|
||||
cleanName := pkg.Name
|
||||
if idx := strings.Index(cleanName, "/@v/"); idx != -1 {
|
||||
cleanName = cleanName[:idx]
|
||||
}
|
||||
|
||||
// Create deduplication key
|
||||
key := cleanName + "@" + pkg.Version
|
||||
|
||||
// Keep the entry with the largest size (typically .zip files)
|
||||
if existing, ok := seen[key]; !ok || pkg.Size > existing.Size {
|
||||
// Create a copy with cleaned name
|
||||
cleanPkg := *pkg
|
||||
cleanPkg.Name = cleanName
|
||||
seen[key] = &cleanPkg
|
||||
}
|
||||
}
|
||||
|
||||
// Convert map to slice
|
||||
packages := make([]*metadata.Package, 0, len(seen))
|
||||
for _, pkg := range seen {
|
||||
packages = append(packages, pkg)
|
||||
}
|
||||
|
||||
// Enhance packages with vulnerability information if security scanning is enabled
|
||||
var response map[string]interface{}
|
||||
if a.config.Security.Enabled {
|
||||
enhancedPackages := make([]map[string]interface{}, 0, len(packages))
|
||||
for _, pkg := range packages {
|
||||
pkgMap := map[string]interface{}{
|
||||
"id": pkg.ID,
|
||||
"registry": pkg.Registry,
|
||||
"name": pkg.Name,
|
||||
"version": pkg.Version,
|
||||
"size": pkg.Size,
|
||||
"checksum_sha256": pkg.ChecksumSHA256,
|
||||
"cached_at": pkg.CachedAt,
|
||||
"last_accessed": pkg.LastAccessed,
|
||||
"download_count": pkg.DownloadCount,
|
||||
}
|
||||
|
||||
// Add vulnerability info if scanned
|
||||
if pkg.SecurityScanned {
|
||||
scanResult, err := a.metadata.GetScanResult(ctx, pkg.Registry, pkg.Name, pkg.Version)
|
||||
if err == nil && scanResult != nil {
|
||||
// Count vulnerabilities by severity
|
||||
severityCounts := make(map[string]int)
|
||||
for _, vuln := range scanResult.Vulnerabilities {
|
||||
severityCounts[strings.ToUpper(vuln.Severity)]++
|
||||
}
|
||||
|
||||
pkgMap["vulnerabilities"] = map[string]interface{}{
|
||||
"scanned": true,
|
||||
"status": scanResult.Status,
|
||||
"scannedAt": scanResult.ScannedAt.Format(time.RFC3339),
|
||||
"counts": map[string]int{
|
||||
"critical": severityCounts["CRITICAL"],
|
||||
"high": severityCounts["HIGH"],
|
||||
"medium": severityCounts["MEDIUM"],
|
||||
"low": severityCounts["LOW"],
|
||||
},
|
||||
"total": scanResult.VulnerabilityCount,
|
||||
}
|
||||
} else {
|
||||
pkgMap["vulnerabilities"] = map[string]interface{}{
|
||||
"scanned": false,
|
||||
"status": "pending",
|
||||
}
|
||||
}
|
||||
} else {
|
||||
pkgMap["vulnerabilities"] = map[string]interface{}{
|
||||
"scanned": false,
|
||||
"status": "not_scanned",
|
||||
}
|
||||
}
|
||||
|
||||
enhancedPackages = append(enhancedPackages, pkgMap)
|
||||
}
|
||||
|
||||
response = map[string]interface{}{
|
||||
"packages": enhancedPackages,
|
||||
"total": len(enhancedPackages),
|
||||
}
|
||||
} else {
|
||||
response = map[string]interface{}{
|
||||
"packages": packages,
|
||||
"total": len(packages),
|
||||
}
|
||||
}
|
||||
|
||||
// Success response
|
||||
errors.WriteJSONSimple(w, http.StatusOK, response)
|
||||
}
|
||||
|
||||
// handleDeletePackage deletes a cached package
|
||||
func (a *App) handleDeletePackage(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Parse path: /api/packages/{registry}/{name}/{version}
|
||||
// For Go packages, name can contain slashes (e.g., github.com/user/repo)
|
||||
// Version is always the last segment
|
||||
path := strings.TrimPrefix(r.URL.Path, "/api/packages/")
|
||||
parts := strings.Split(path, "/")
|
||||
if len(parts) < 3 {
|
||||
errors.WriteErrorSimple(w, errors.BadRequest("invalid path format, expected /api/packages/{registry}/{name}/{version}"))
|
||||
return
|
||||
}
|
||||
|
||||
registry := parts[0]
|
||||
version := parts[len(parts)-1]
|
||||
name := strings.Join(parts[1:len(parts)-1], "/")
|
||||
|
||||
// For Go packages, we need to find and delete all cache entries (.info, .mod, .zip)
|
||||
// For other registries, we can delete directly
|
||||
var deletedCount int
|
||||
var lastErr error
|
||||
|
||||
if registry == "go" {
|
||||
// List all packages matching the base name and version
|
||||
allPackages, err := a.metadata.ListPackages(ctx, &metadata.ListOptions{
|
||||
Limit: 1000,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to list packages for deletion")
|
||||
errors.WriteErrorSimple(w, errors.InternalServer("failed to list packages"))
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("registry", registry).
|
||||
Str("name", name).
|
||||
Str("version", version).
|
||||
Int("total_packages", len(allPackages)).
|
||||
Msg("Searching for packages to delete")
|
||||
|
||||
// Find and delete all entries for this package
|
||||
for _, pkg := range allPackages {
|
||||
if pkg.Registry != registry || pkg.Version != version {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this package name matches (either exact or with /@v/ suffix)
|
||||
cleanName := pkg.Name
|
||||
if idx := strings.Index(cleanName, "/@v/"); idx != -1 {
|
||||
cleanName = cleanName[:idx]
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("db_name", pkg.Name).
|
||||
Str("clean_name", cleanName).
|
||||
Str("search_name", name).
|
||||
Bool("matches", cleanName == name).
|
||||
Msg("Checking package")
|
||||
|
||||
if cleanName == name {
|
||||
if err := a.cache.Delete(ctx, pkg.Registry, pkg.Name, pkg.Version); err != nil {
|
||||
log.Warn().
|
||||
Err(err).
|
||||
Str("registry", pkg.Registry).
|
||||
Str("name", pkg.Name).
|
||||
Str("version", pkg.Version).
|
||||
Msg("Failed to delete package variant")
|
||||
lastErr = err
|
||||
} else {
|
||||
deletedCount++
|
||||
log.Info().
|
||||
Str("registry", pkg.Registry).
|
||||
Str("name", pkg.Name).
|
||||
Str("version", pkg.Version).
|
||||
Msg("Deleted package variant")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Int("deleted_count", deletedCount).
|
||||
Msg("Delete operation completed")
|
||||
|
||||
if deletedCount == 0 {
|
||||
errors.WriteErrorSimple(w, errors.NotFound("package not found"))
|
||||
return
|
||||
}
|
||||
|
||||
if lastErr != nil && deletedCount == 0 {
|
||||
errors.WriteErrorSimple(w, errors.InternalServer("failed to delete package"))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// For NPM and PyPI, delete directly
|
||||
if err := a.cache.Delete(ctx, registry, name, version); err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("registry", registry).
|
||||
Str("name", name).
|
||||
Str("version", version).
|
||||
Msg("Failed to delete package")
|
||||
errors.WriteErrorSimple(w, errors.InternalServer("failed to delete package"))
|
||||
return
|
||||
}
|
||||
deletedCount = 1
|
||||
}
|
||||
|
||||
// Broadcast event via WebSocket
|
||||
a.wsServer.Broadcast(websocket.EventPackageDeleted, map[string]interface{}{
|
||||
"registry": registry,
|
||||
"name": name,
|
||||
"version": version,
|
||||
})
|
||||
|
||||
// Success response
|
||||
response := map[string]interface{}{
|
||||
"deleted": true,
|
||||
"package": map[string]string{
|
||||
"registry": registry,
|
||||
"name": name,
|
||||
"version": version,
|
||||
},
|
||||
}
|
||||
|
||||
// For Go packages, include count of deleted variants
|
||||
if registry == "go" {
|
||||
response["deleted_count"] = deletedCount
|
||||
}
|
||||
|
||||
errors.WriteJSONSimple(w, http.StatusOK, response)
|
||||
}
|
||||
|
||||
// handleStats handles /api/stats endpoint
|
||||
func (a *App) handleStats(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
|
||||
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method != "GET" {
|
||||
errors.WriteErrorSimple(w, errors.BadRequest("method not allowed"))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
// Get cache statistics for all registries
|
||||
cacheStats, err := a.cache.GetStats(ctx, "")
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get cache stats")
|
||||
cacheStats = &metadata.Stats{}
|
||||
}
|
||||
|
||||
// Get all packages to calculate total size and downloads
|
||||
packages, err := a.metadata.ListPackages(ctx, nil)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to list packages")
|
||||
packages = []*metadata.Package{}
|
||||
}
|
||||
|
||||
// Calculate totals and registry breakdown from actual packages (exclude metadata entries like "list", "latest")
|
||||
var totalSize int64
|
||||
var totalDownloads int64
|
||||
var actualPackageCount int
|
||||
registryStats := make(map[string]map[string]interface{})
|
||||
|
||||
for _, pkg := range packages {
|
||||
// Skip metadata entries
|
||||
if pkg.Version == "list" || pkg.Version == "latest" {
|
||||
continue
|
||||
}
|
||||
totalSize += pkg.Size
|
||||
totalDownloads += int64(pkg.DownloadCount)
|
||||
actualPackageCount++
|
||||
|
||||
// Track per-registry stats
|
||||
if _, ok := registryStats[pkg.Registry]; !ok {
|
||||
registryStats[pkg.Registry] = map[string]interface{}{
|
||||
"count": 0,
|
||||
"size": int64(0),
|
||||
"downloads": int64(0),
|
||||
}
|
||||
}
|
||||
registryStats[pkg.Registry]["count"] = registryStats[pkg.Registry]["count"].(int) + 1
|
||||
registryStats[pkg.Registry]["size"] = registryStats[pkg.Registry]["size"].(int64) + pkg.Size
|
||||
registryStats[pkg.Registry]["downloads"] = registryStats[pkg.Registry]["downloads"].(int64) + int64(pkg.DownloadCount)
|
||||
}
|
||||
|
||||
// Combine statistics
|
||||
stats := map[string]interface{}{
|
||||
"total_packages": actualPackageCount,
|
||||
"total_downloads": totalDownloads,
|
||||
"total_size": totalSize,
|
||||
"cache_hits": cacheStats.TotalDownloads,
|
||||
"cache_misses": 0, // TODO: Track cache misses
|
||||
"cache_evictions": 0, // TODO: Track evictions
|
||||
"cache_size": cacheStats.TotalSize,
|
||||
"scanned_packages": cacheStats.ScannedPackages,
|
||||
"vulnerable_packages": cacheStats.VulnerablePackages,
|
||||
}
|
||||
|
||||
// Convert registry stats to interface map
|
||||
registries := make(map[string]interface{})
|
||||
for registry, regStats := range registryStats {
|
||||
registries[registry] = regStats
|
||||
}
|
||||
|
||||
errors.WriteJSONSimple(w, http.StatusOK, map[string]interface{}{
|
||||
"stats": stats,
|
||||
"registries": registries,
|
||||
})
|
||||
}
|
||||
|
||||
// handleInfo handles /api/info endpoint
|
||||
func (a *App) handleInfo(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
|
||||
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method != "GET" {
|
||||
errors.WriteErrorSimple(w, errors.BadRequest("method not allowed"))
|
||||
return
|
||||
}
|
||||
|
||||
info := map[string]interface{}{
|
||||
"name": "GoHoarder",
|
||||
"version": version.Version,
|
||||
"config": map[string]interface{}{
|
||||
"storage_backend": a.config.Storage.Backend,
|
||||
"metadata_backend": a.config.Metadata.Backend,
|
||||
"cache_ttl": a.config.Cache.DefaultTTL.String(),
|
||||
"max_cache_size": a.config.Cache.MaxSizeBytes,
|
||||
},
|
||||
"features": map[string]bool{
|
||||
"distributed_locking": a.lockManager != nil,
|
||||
"security_scanning": a.config.Security.Enabled,
|
||||
"pre_warming": a.prewarmWorker != nil,
|
||||
"websockets": true,
|
||||
"analytics": true,
|
||||
},
|
||||
}
|
||||
|
||||
errors.WriteJSONSimple(w, http.StatusOK, info)
|
||||
}
|
||||
@@ -0,0 +1,323 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/auth"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// requireAdmin middleware checks for admin authentication
|
||||
func (a *App) requireAdmin(c *fiber.Ctx) error {
|
||||
// Get API key from Authorization header
|
||||
authHeader := c.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
||||
"error": "missing authorization header",
|
||||
})
|
||||
}
|
||||
|
||||
// Extract bearer token
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
||||
"error": "invalid authorization header format, expected: Bearer <token>",
|
||||
})
|
||||
}
|
||||
|
||||
apiKey := parts[1]
|
||||
|
||||
// Validate API key
|
||||
key, err := a.authManager.ValidateAPIKey(c.Context(), apiKey)
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
||||
"error": "invalid or expired API key",
|
||||
})
|
||||
}
|
||||
|
||||
// Check if user has admin role or bypass management permission
|
||||
if key.Role != auth.RoleAdmin && !key.HasPermission(auth.PermissionManageBypasses) {
|
||||
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
|
||||
"error": "insufficient permissions, admin role required",
|
||||
})
|
||||
}
|
||||
|
||||
// Continue to next handler
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// handleAdminBypasses handles /api/admin/bypasses endpoint
|
||||
func (a *App) handleAdminBypasses(c *fiber.Ctx) error {
|
||||
c.Set("Content-Type", "application/json")
|
||||
c.Set("Access-Control-Allow-Origin", "*")
|
||||
c.Set("Access-Control-Allow-Methods", "GET, POST, PATCH, DELETE, OPTIONS")
|
||||
c.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
|
||||
if c.Method() == "OPTIONS" {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
}
|
||||
|
||||
// Check if there's an ID parameter
|
||||
id := c.Params("id")
|
||||
|
||||
switch c.Method() {
|
||||
case "GET":
|
||||
if id != "" {
|
||||
return a.handleGetBypass(c)
|
||||
}
|
||||
return a.handleListBypasses(c)
|
||||
case "POST":
|
||||
return a.handleCreateBypass(c)
|
||||
case "PATCH":
|
||||
return a.handleUpdateBypass(c)
|
||||
case "DELETE":
|
||||
return a.handleDeleteBypass(c)
|
||||
default:
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "method not allowed"})
|
||||
}
|
||||
}
|
||||
|
||||
// handleListBypasses lists all CVE bypasses
|
||||
func (a *App) handleListBypasses(c *fiber.Ctx) error {
|
||||
ctx := c.Context()
|
||||
|
||||
// Parse query parameters
|
||||
includeExpired := c.Query("include_expired") == "true"
|
||||
activeOnly := c.Query("active_only") == "true"
|
||||
bypassType := metadata.BypassType(c.Query("type"))
|
||||
|
||||
opts := &metadata.BypassListOptions{
|
||||
IncludeExpired: includeExpired,
|
||||
ActiveOnly: activeOnly,
|
||||
Type: bypassType,
|
||||
}
|
||||
|
||||
bypasses, err := a.metadata.ListCVEBypasses(ctx, opts)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to list CVE bypasses")
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "failed to list bypasses"})
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(fiber.Map{
|
||||
"bypasses": bypasses,
|
||||
"total": len(bypasses),
|
||||
})
|
||||
}
|
||||
|
||||
// CreateBypassRequest represents the request body for creating a bypass
|
||||
type CreateBypassRequest struct {
|
||||
Type metadata.BypassType `json:"type"` // "cve" or "package"
|
||||
Target string `json:"target"` // CVE ID or package name
|
||||
Reason string `json:"reason"` // Why this bypass is needed
|
||||
CreatedBy string `json:"created_by"` // Admin username
|
||||
ExpiresInHours int `json:"expires_in_hours"` // How many hours until expiration
|
||||
AppliesTo string `json:"applies_to,omitempty"` // Optional: limit CVE bypass to specific package
|
||||
NotifyOnExpiry bool `json:"notify_on_expiry"` // Send notification when expired
|
||||
}
|
||||
|
||||
// handleCreateBypass creates a new CVE bypass
|
||||
func (a *App) handleCreateBypass(c *fiber.Ctx) error {
|
||||
ctx := c.Context()
|
||||
|
||||
var req CreateBypassRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid JSON in request body"})
|
||||
}
|
||||
|
||||
// Validate request
|
||||
if req.Type != metadata.BypassTypeCVE && req.Type != metadata.BypassTypePackage {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "type must be 'cve' or 'package'"})
|
||||
}
|
||||
|
||||
if req.Target == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "target is required"})
|
||||
}
|
||||
|
||||
if req.Reason == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "reason is required"})
|
||||
}
|
||||
|
||||
if req.CreatedBy == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "created_by is required"})
|
||||
}
|
||||
|
||||
if req.ExpiresInHours <= 0 {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "expires_in_hours must be greater than 0"})
|
||||
}
|
||||
|
||||
// Create bypass
|
||||
now := time.Now()
|
||||
expiresAt := now.Add(time.Duration(req.ExpiresInHours) * time.Hour)
|
||||
|
||||
bypass := &metadata.CVEBypass{
|
||||
ID: uuid.New().String(),
|
||||
Type: req.Type,
|
||||
Target: req.Target,
|
||||
Reason: req.Reason,
|
||||
CreatedBy: req.CreatedBy,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: expiresAt,
|
||||
AppliesTo: req.AppliesTo,
|
||||
NotifyOnExpiry: req.NotifyOnExpiry,
|
||||
Active: true,
|
||||
}
|
||||
|
||||
// Save to database
|
||||
if err := a.metadata.SaveCVEBypass(ctx, bypass); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to save CVE bypass")
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "failed to create bypass"})
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("bypass_id", bypass.ID).
|
||||
Str("type", string(bypass.Type)).
|
||||
Str("target", bypass.Target).
|
||||
Str("created_by", bypass.CreatedBy).
|
||||
Time("expires_at", bypass.ExpiresAt).
|
||||
Msg("CVE bypass created")
|
||||
|
||||
return c.Status(fiber.StatusCreated).JSON(fiber.Map{
|
||||
"bypass": bypass,
|
||||
"message": "Bypass created successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// handleGetBypass gets a specific bypass by ID
|
||||
func (a *App) handleGetBypass(c *fiber.Ctx) error {
|
||||
ctx := c.Context()
|
||||
|
||||
// Extract ID from parameter
|
||||
bypassID := c.Params("id")
|
||||
|
||||
if bypassID == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "bypass ID is required"})
|
||||
}
|
||||
|
||||
// Get all bypasses and find the one with matching ID
|
||||
bypasses, err := a.metadata.ListCVEBypasses(ctx, &metadata.BypassListOptions{
|
||||
IncludeExpired: true,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to list bypasses")
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "failed to get bypass"})
|
||||
}
|
||||
|
||||
for _, bypass := range bypasses {
|
||||
if bypass.ID == bypassID {
|
||||
return c.Status(fiber.StatusOK).JSON(fiber.Map{
|
||||
"bypass": bypass,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "bypass not found"})
|
||||
}
|
||||
|
||||
// UpdateBypassRequest represents the request body for updating a bypass
|
||||
type UpdateBypassRequest struct {
|
||||
Active *bool `json:"active,omitempty"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
ExpiresInHours int `json:"expires_in_hours,omitempty"`
|
||||
}
|
||||
|
||||
// handleUpdateBypass updates a bypass (activate/deactivate or extend expiration)
|
||||
func (a *App) handleUpdateBypass(c *fiber.Ctx) error {
|
||||
ctx := c.Context()
|
||||
|
||||
// Extract ID from parameter
|
||||
bypassID := c.Params("id")
|
||||
|
||||
if bypassID == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "bypass ID is required"})
|
||||
}
|
||||
|
||||
var req UpdateBypassRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid JSON in request body"})
|
||||
}
|
||||
|
||||
// Get current bypass
|
||||
bypasses, err := a.metadata.ListCVEBypasses(ctx, &metadata.BypassListOptions{
|
||||
IncludeExpired: true,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to list bypasses")
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "failed to get bypass"})
|
||||
}
|
||||
|
||||
var currentBypass *metadata.CVEBypass
|
||||
for _, bypass := range bypasses {
|
||||
if bypass.ID == bypassID {
|
||||
currentBypass = bypass
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if currentBypass == nil {
|
||||
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "bypass not found"})
|
||||
}
|
||||
|
||||
// Update fields
|
||||
if req.Active != nil {
|
||||
currentBypass.Active = *req.Active
|
||||
}
|
||||
|
||||
if req.Reason != "" {
|
||||
currentBypass.Reason = req.Reason
|
||||
}
|
||||
|
||||
if req.ExpiresInHours > 0 {
|
||||
currentBypass.ExpiresAt = time.Now().Add(time.Duration(req.ExpiresInHours) * time.Hour)
|
||||
}
|
||||
|
||||
// Save updated bypass
|
||||
if err := a.metadata.SaveCVEBypass(ctx, currentBypass); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to update bypass")
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "failed to update bypass"})
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("bypass_id", currentBypass.ID).
|
||||
Bool("active", currentBypass.Active).
|
||||
Msg("CVE bypass updated")
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(fiber.Map{
|
||||
"bypass": currentBypass,
|
||||
"message": "Bypass updated successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// handleDeleteBypass deletes a bypass
|
||||
func (a *App) handleDeleteBypass(c *fiber.Ctx) error {
|
||||
ctx := c.Context()
|
||||
|
||||
// Extract ID from parameter
|
||||
bypassID := c.Params("id")
|
||||
|
||||
if bypassID == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "bypass ID is required"})
|
||||
}
|
||||
|
||||
// Delete bypass
|
||||
if err := a.metadata.DeleteCVEBypass(ctx, bypassID); err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "bypass not found"})
|
||||
}
|
||||
log.Error().Err(err).Msg("Failed to delete bypass")
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "failed to delete bypass"})
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("bypass_id", bypassID).
|
||||
Msg("CVE bypass deleted")
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(fiber.Map{
|
||||
"deleted": true,
|
||||
"bypass_id": bypassID,
|
||||
"message": "Bypass deleted successfully",
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,156 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// handleVulnerabilities handles /api/packages/{registry}/{name}/{version}/vulnerabilities endpoint
|
||||
func (a *App) handleVulnerabilities(c *fiber.Ctx) error {
|
||||
c.Set("Content-Type", "application/json")
|
||||
c.Set("Access-Control-Allow-Origin", "*")
|
||||
c.Set("Access-Control-Allow-Methods", "GET, OPTIONS")
|
||||
c.Set("Access-Control-Allow-Headers", "Content-Type")
|
||||
|
||||
if c.Method() == "OPTIONS" {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
}
|
||||
|
||||
if c.Method() != "GET" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "method not allowed"})
|
||||
}
|
||||
|
||||
ctx := c.Context()
|
||||
|
||||
// Parse path: /api/packages/{registry}/{name}/{version}/vulnerabilities
|
||||
path := strings.TrimPrefix(c.Path(), "/api/packages/")
|
||||
path = strings.TrimSuffix(path, "/vulnerabilities")
|
||||
parts := strings.Split(path, "/")
|
||||
if len(parts) < 3 {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
"error": "invalid path format, expected /api/packages/{registry}/{name}/{version}/vulnerabilities",
|
||||
})
|
||||
}
|
||||
|
||||
registry := parts[0]
|
||||
version := parts[len(parts)-1]
|
||||
name := strings.Join(parts[1:len(parts)-1], "/")
|
||||
|
||||
log.Debug().
|
||||
Str("registry", registry).
|
||||
Str("name", name).
|
||||
Str("version", version).
|
||||
Msg("Getting vulnerabilities for package")
|
||||
|
||||
// Get scan result from metadata store
|
||||
scanResult, err := a.metadata.GetScanResult(ctx, registry, name, version)
|
||||
if err != nil {
|
||||
// Check if package exists
|
||||
pkg, pkgErr := a.metadata.GetPackage(ctx, registry, name, version)
|
||||
if pkgErr != nil {
|
||||
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "package not found"})
|
||||
}
|
||||
|
||||
// Package exists but not scanned yet
|
||||
return c.Status(fiber.StatusOK).JSON(fiber.Map{
|
||||
"package": fiber.Map{
|
||||
"registry": registry,
|
||||
"name": name,
|
||||
"version": version,
|
||||
},
|
||||
"scanned": false,
|
||||
"status": "pending",
|
||||
"vulnerabilities": []interface{}{},
|
||||
"vulnerability_count": 0,
|
||||
"message": "Package not yet scanned for vulnerabilities",
|
||||
"security_scanned": pkg.SecurityScanned,
|
||||
})
|
||||
}
|
||||
|
||||
// Get active bypasses to show which vulnerabilities are bypassed
|
||||
bypasses, err := a.metadata.GetActiveCVEBypasses(ctx)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to get CVE bypasses")
|
||||
bypasses = []*metadata.CVEBypass{}
|
||||
}
|
||||
|
||||
// Build bypass map for fast lookup
|
||||
bypassedCVEs := make(map[string]*metadata.CVEBypass)
|
||||
packageKey := registry + "/" + name + "@" + version
|
||||
packageKeyNoVersion := registry + "/" + name
|
||||
|
||||
for _, bypass := range bypasses {
|
||||
if bypass.Type == metadata.BypassTypeCVE && bypass.Active {
|
||||
// Check if bypass applies to this package
|
||||
if bypass.AppliesTo == "" || bypass.AppliesTo == packageKey || bypass.AppliesTo == packageKeyNoVersion {
|
||||
bypassedCVEs[strings.ToUpper(bypass.Target)] = bypass
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Enrich vulnerabilities with bypass information
|
||||
enrichedVulns := make([]map[string]interface{}, 0, len(scanResult.Vulnerabilities))
|
||||
severityCounts := make(map[string]int)
|
||||
|
||||
for _, vuln := range scanResult.Vulnerabilities {
|
||||
bypassed := false
|
||||
var bypassInfo map[string]interface{}
|
||||
|
||||
// Check if this CVE is bypassed
|
||||
if bypass, ok := bypassedCVEs[strings.ToUpper(vuln.ID)]; ok {
|
||||
bypassed = true
|
||||
bypassInfo = map[string]interface{}{
|
||||
"id": bypass.ID,
|
||||
"reason": bypass.Reason,
|
||||
"created_by": bypass.CreatedBy,
|
||||
"expires_at": bypass.ExpiresAt,
|
||||
}
|
||||
} else {
|
||||
// Count non-bypassed vulnerabilities by severity
|
||||
severityCounts[strings.ToUpper(vuln.Severity)]++
|
||||
}
|
||||
|
||||
enrichedVuln := map[string]interface{}{
|
||||
"id": vuln.ID,
|
||||
"severity": vuln.Severity,
|
||||
"title": vuln.Title,
|
||||
"description": vuln.Description,
|
||||
"references": vuln.References,
|
||||
"fixed_in": vuln.FixedIn,
|
||||
"bypassed": bypassed,
|
||||
}
|
||||
|
||||
if bypassed {
|
||||
enrichedVuln["bypass"] = bypassInfo
|
||||
}
|
||||
|
||||
enrichedVulns = append(enrichedVulns, enrichedVuln)
|
||||
}
|
||||
|
||||
// Build response
|
||||
response := fiber.Map{
|
||||
"package": fiber.Map{
|
||||
"registry": registry,
|
||||
"name": name,
|
||||
"version": version,
|
||||
},
|
||||
"scanned": true,
|
||||
"scanner": scanResult.Scanner,
|
||||
"scanned_at": scanResult.ScannedAt,
|
||||
"status": scanResult.Status,
|
||||
"vulnerabilities": enrichedVulns,
|
||||
"vulnerability_count": scanResult.VulnerabilityCount,
|
||||
"severity_counts": fiber.Map{
|
||||
"critical": severityCounts["CRITICAL"],
|
||||
"high": severityCounts["HIGH"],
|
||||
"moderate": severityCounts["MODERATE"],
|
||||
"low": severityCounts["LOW"],
|
||||
},
|
||||
"bypassed_count": len(scanResult.Vulnerabilities) - (severityCounts["CRITICAL"] + severityCounts["HIGH"] + severityCounts["MODERATE"] + severityCounts["LOW"]),
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(response)
|
||||
}
|
||||
@@ -0,0 +1,193 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/errors"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// Manager handles authentication and authorization
|
||||
type Manager struct {
|
||||
keys map[string]*APIKey
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// APIKey represents an API key
|
||||
type APIKey struct {
|
||||
ID string
|
||||
Name string
|
||||
HashedKey string
|
||||
Role Role
|
||||
CreatedAt time.Time
|
||||
ExpiresAt *time.Time
|
||||
LastUsedAt time.Time
|
||||
Permissions []Permission
|
||||
}
|
||||
|
||||
// Role represents user role
|
||||
type Role string
|
||||
|
||||
const (
|
||||
RoleReadOnly Role = "readonly"
|
||||
RoleReadWrite Role = "readwrite"
|
||||
RoleAdmin Role = "admin"
|
||||
)
|
||||
|
||||
// Permission represents a specific permission
|
||||
type Permission string
|
||||
|
||||
const (
|
||||
PermissionReadPackage Permission = "package:read"
|
||||
PermissionWritePackage Permission = "package:write"
|
||||
PermissionDeletePackage Permission = "package:delete"
|
||||
PermissionViewStats Permission = "stats:view"
|
||||
PermissionManageKeys Permission = "keys:manage"
|
||||
PermissionManageSettings Permission = "settings:manage"
|
||||
PermissionScanPackages Permission = "scan:execute"
|
||||
PermissionManageBypasses Permission = "bypasses:manage"
|
||||
)
|
||||
|
||||
// New creates a new authentication manager
|
||||
func New() *Manager {
|
||||
return &Manager{
|
||||
keys: make(map[string]*APIKey),
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateAPIKey generates a new API key
|
||||
func (m *Manager) GenerateAPIKey(name string, role Role, expiresIn *time.Duration) (*APIKey, string, error) {
|
||||
// Generate random key
|
||||
keyBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(keyBytes); err != nil {
|
||||
return nil, "", errors.Wrap(err, errors.ErrCodeInternalServer, "failed to generate random key")
|
||||
}
|
||||
|
||||
rawKey := base64.URLEncoding.EncodeToString(keyBytes)
|
||||
|
||||
// Hash the key
|
||||
hashedKey, err := bcrypt.GenerateFromPassword([]byte(rawKey), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, "", errors.Wrap(err, errors.ErrCodeInternalServer, "failed to hash key")
|
||||
}
|
||||
|
||||
var expiresAt *time.Time
|
||||
if expiresIn != nil {
|
||||
t := time.Now().Add(*expiresIn)
|
||||
expiresAt = &t
|
||||
}
|
||||
|
||||
apiKey := &APIKey{
|
||||
ID: generateID(),
|
||||
Name: name,
|
||||
HashedKey: string(hashedKey),
|
||||
Role: role,
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresAt: expiresAt,
|
||||
Permissions: getPermissionsForRole(role),
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.keys[apiKey.ID] = apiKey
|
||||
m.mu.Unlock()
|
||||
|
||||
return apiKey, rawKey, nil
|
||||
}
|
||||
|
||||
// ValidateAPIKey validates an API key and returns the associated key object
|
||||
func (m *Manager) ValidateAPIKey(ctx context.Context, rawKey string) (*APIKey, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
for _, apiKey := range m.keys {
|
||||
// Check if key is expired
|
||||
if apiKey.ExpiresAt != nil && time.Now().After(*apiKey.ExpiresAt) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Compare hashed key
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(apiKey.HashedKey), []byte(rawKey)); err == nil {
|
||||
// Update last used
|
||||
apiKey.LastUsedAt = time.Now()
|
||||
return apiKey, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New(errors.ErrCodeUnauthorized, "invalid API key")
|
||||
}
|
||||
|
||||
// RevokeAPIKey revokes an API key
|
||||
func (m *Manager) RevokeAPIKey(keyID string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, exists := m.keys[keyID]; !exists {
|
||||
return errors.NotFound("API key not found")
|
||||
}
|
||||
|
||||
delete(m.keys, keyID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListAPIKeys lists all API keys
|
||||
func (m *Manager) ListAPIKeys() []*APIKey {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
keys := make([]*APIKey, 0, len(m.keys))
|
||||
for _, key := range m.keys {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// HasPermission checks if an API key has a specific permission
|
||||
func (k *APIKey) HasPermission(permission Permission) bool {
|
||||
for _, p := range k.Permissions {
|
||||
if p == permission {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getPermissionsForRole returns permissions for a role
|
||||
func getPermissionsForRole(role Role) []Permission {
|
||||
switch role {
|
||||
case RoleReadOnly:
|
||||
return []Permission{
|
||||
PermissionReadPackage,
|
||||
PermissionViewStats,
|
||||
}
|
||||
case RoleReadWrite:
|
||||
return []Permission{
|
||||
PermissionReadPackage,
|
||||
PermissionWritePackage,
|
||||
PermissionViewStats,
|
||||
}
|
||||
case RoleAdmin:
|
||||
return []Permission{
|
||||
PermissionReadPackage,
|
||||
PermissionWritePackage,
|
||||
PermissionDeletePackage,
|
||||
PermissionViewStats,
|
||||
PermissionManageKeys,
|
||||
PermissionManageSettings,
|
||||
PermissionScanPackages,
|
||||
PermissionManageBypasses,
|
||||
}
|
||||
default:
|
||||
return []Permission{}
|
||||
}
|
||||
}
|
||||
|
||||
// generateID generates a unique ID
|
||||
func generateID() string {
|
||||
b := make([]byte, 16)
|
||||
_, _ = rand.Read(b) // #nosec G104 -- Rand read always succeeds
|
||||
return base64.URLEncoding.EncodeToString(b)
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CredentialExtractor extracts authentication credentials from HTTP requests
|
||||
type CredentialExtractor struct{}
|
||||
|
||||
// NewCredentialExtractor creates a new credential extractor
|
||||
func NewCredentialExtractor() *CredentialExtractor {
|
||||
return &CredentialExtractor{}
|
||||
}
|
||||
|
||||
// Extract extracts authentication credentials from an HTTP request
|
||||
// Returns the full Authorization header value or constructed auth string
|
||||
func (e *CredentialExtractor) Extract(r *http.Request) string {
|
||||
// Try Authorization header first (most common)
|
||||
if auth := r.Header.Get("Authorization"); auth != "" {
|
||||
return auth
|
||||
}
|
||||
|
||||
// Try Basic auth from URL (for PyPI compatibility)
|
||||
if username, password, ok := r.BasicAuth(); ok {
|
||||
auth := base64.StdEncoding.EncodeToString([]byte(username + ":" + password))
|
||||
return "Basic " + auth
|
||||
}
|
||||
|
||||
// No credentials found
|
||||
return ""
|
||||
}
|
||||
|
||||
// ExtractScheme returns the authentication scheme (Bearer, Basic, Token)
|
||||
func (e *CredentialExtractor) ExtractScheme(r *http.Request) string {
|
||||
auth := e.Extract(r)
|
||||
if auth == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
parts := strings.SplitN(auth, " ", 2)
|
||||
if len(parts) == 2 {
|
||||
return parts[0]
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// ExtractToken extracts just the token part (without scheme)
|
||||
func (e *CredentialExtractor) ExtractToken(r *http.Request) string {
|
||||
auth := e.Extract(r)
|
||||
if auth == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Remove scheme prefix
|
||||
auth = strings.TrimPrefix(auth, "Bearer ")
|
||||
auth = strings.TrimPrefix(auth, "Token ")
|
||||
auth = strings.TrimPrefix(auth, "Basic ")
|
||||
|
||||
return auth
|
||||
}
|
||||
|
||||
// HasCredentials checks if request has any credentials
|
||||
func (e *CredentialExtractor) HasCredentials(r *http.Request) bool {
|
||||
return e.Extract(r) != ""
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// CredentialHasher generates hashes of credentials for cache keys
|
||||
type CredentialHasher struct{}
|
||||
|
||||
// NewCredentialHasher creates a new credential hasher
|
||||
func NewCredentialHasher() *CredentialHasher {
|
||||
return &CredentialHasher{}
|
||||
}
|
||||
|
||||
// Hash generates a short hash of credentials for use in cache keys
|
||||
// Returns "public" if no credentials provided
|
||||
func (h *CredentialHasher) Hash(credentials string) string {
|
||||
if credentials == "" {
|
||||
return "public"
|
||||
}
|
||||
|
||||
// Use SHA256 and take first 16 characters (8 bytes)
|
||||
hash := sha256.Sum256([]byte(credentials))
|
||||
return hex.EncodeToString(hash[:8])
|
||||
}
|
||||
|
||||
// GenerateCacheKey generates a cache key that includes credential hash
|
||||
func (h *CredentialHasher) GenerateCacheKey(registry, packageName, version, credentials string) string {
|
||||
credHash := h.Hash(credentials)
|
||||
return fmt.Sprintf("%s:%s:%s:%s", registry, packageName, version, credHash)
|
||||
}
|
||||
|
||||
// IsPublicKey checks if a cache key is for public packages (no credentials)
|
||||
func (h *CredentialHasher) IsPublicKey(cacheKey string) bool {
|
||||
return len(cacheKey) > 0 && cacheKey[len(cacheKey)-6:] == "public"
|
||||
}
|
||||
@@ -0,0 +1,109 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ValidationResult represents a cached credential validation result
|
||||
type ValidationResult struct {
|
||||
Allowed bool
|
||||
ExpiresAt time.Time
|
||||
Reason string
|
||||
}
|
||||
|
||||
// ValidationCache caches credential validation results to reduce upstream checks
|
||||
type ValidationCache struct {
|
||||
cache map[string]*ValidationResult
|
||||
mu sync.RWMutex
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// NewValidationCache creates a new validation cache
|
||||
func NewValidationCache(ttl time.Duration) *ValidationCache {
|
||||
vc := &ValidationCache{
|
||||
cache: make(map[string]*ValidationResult),
|
||||
ttl: ttl,
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
go vc.cleanupExpired()
|
||||
|
||||
return vc
|
||||
}
|
||||
|
||||
// Get retrieves a validation result from cache
|
||||
// Returns (allowed bool, cached bool, reason string)
|
||||
func (vc *ValidationCache) Get(credHash, packageURL string) (bool, bool, string) {
|
||||
vc.mu.RLock()
|
||||
defer vc.mu.RUnlock()
|
||||
|
||||
key := credHash + ":" + packageURL
|
||||
result, exists := vc.cache[key]
|
||||
|
||||
if !exists {
|
||||
return false, false, ""
|
||||
}
|
||||
|
||||
// Check if expired
|
||||
if time.Now().After(result.ExpiresAt) {
|
||||
return false, false, ""
|
||||
}
|
||||
|
||||
return result.Allowed, true, result.Reason
|
||||
}
|
||||
|
||||
// Set stores a validation result in cache
|
||||
func (vc *ValidationCache) Set(credHash, packageURL string, allowed bool, reason string) {
|
||||
vc.mu.Lock()
|
||||
defer vc.mu.Unlock()
|
||||
|
||||
key := credHash + ":" + packageURL
|
||||
vc.cache[key] = &ValidationResult{
|
||||
Allowed: allowed,
|
||||
ExpiresAt: time.Now().Add(vc.ttl),
|
||||
Reason: reason,
|
||||
}
|
||||
}
|
||||
|
||||
// Invalidate removes a specific entry from cache
|
||||
func (vc *ValidationCache) Invalidate(credHash, packageURL string) {
|
||||
vc.mu.Lock()
|
||||
defer vc.mu.Unlock()
|
||||
|
||||
key := credHash + ":" + packageURL
|
||||
delete(vc.cache, key)
|
||||
}
|
||||
|
||||
// InvalidateAll clears the entire cache
|
||||
func (vc *ValidationCache) InvalidateAll() {
|
||||
vc.mu.Lock()
|
||||
defer vc.mu.Unlock()
|
||||
|
||||
vc.cache = make(map[string]*ValidationResult)
|
||||
}
|
||||
|
||||
// Size returns the number of cached entries
|
||||
func (vc *ValidationCache) Size() int {
|
||||
vc.mu.RLock()
|
||||
defer vc.mu.RUnlock()
|
||||
|
||||
return len(vc.cache)
|
||||
}
|
||||
|
||||
// cleanupExpired removes expired entries periodically
|
||||
func (vc *ValidationCache) cleanupExpired() {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
vc.mu.Lock()
|
||||
now := time.Now()
|
||||
for key, result := range vc.cache {
|
||||
if now.After(result.ExpiresAt) {
|
||||
delete(vc.cache, key)
|
||||
}
|
||||
}
|
||||
vc.mu.Unlock()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,284 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// CredentialValidator validates credentials with upstream registries
|
||||
type CredentialValidator interface {
|
||||
// ValidateAccess checks if credentials grant access to a package
|
||||
// Returns (allowed bool, error)
|
||||
ValidateAccess(ctx context.Context, packageURL string, credentials string) (bool, error)
|
||||
}
|
||||
|
||||
// NPMValidator validates npm registry credentials
|
||||
type NPMValidator struct {
|
||||
client *http.Client
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// NewNPMValidator creates a new npm credential validator
|
||||
func NewNPMValidator() *NPMValidator {
|
||||
return &NPMValidator{
|
||||
client: &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
},
|
||||
timeout: 5 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateAccess validates npm package access using HEAD request
|
||||
func (v *NPMValidator) ValidateAccess(ctx context.Context, packageURL string, credentials string) (bool, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "HEAD", packageURL, nil)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Add credentials if provided
|
||||
if credentials != "" {
|
||||
req.Header.Set("Authorization", credentials)
|
||||
}
|
||||
|
||||
resp, err := v.client.Do(req)
|
||||
if err != nil {
|
||||
// Network error - allow cache fallback with warning
|
||||
log.Warn().Err(err).Str("url", packageURL).Msg("Validation request failed, allowing cache fallback")
|
||||
return true, fmt.Errorf("validation failed: %w (allowing cache fallback)", err)
|
||||
}
|
||||
defer resp.Body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
// Check status code
|
||||
switch resp.StatusCode {
|
||||
case 200, 304:
|
||||
// Access granted
|
||||
return true, nil
|
||||
case 401, 403, 404:
|
||||
// Access denied
|
||||
return false, fmt.Errorf("access denied: HTTP %d", resp.StatusCode)
|
||||
default:
|
||||
// Unexpected status - allow cache fallback with warning
|
||||
log.Warn().Int("status", resp.StatusCode).Str("url", packageURL).Msg("Unexpected validation status, allowing cache fallback")
|
||||
return true, fmt.Errorf("unexpected status %d (allowing cache fallback)", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
// PyPIValidator validates PyPI registry credentials
|
||||
type PyPIValidator struct {
|
||||
client *http.Client
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// NewPyPIValidator creates a new PyPI credential validator
|
||||
func NewPyPIValidator() *PyPIValidator {
|
||||
return &PyPIValidator{
|
||||
client: &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
},
|
||||
timeout: 5 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateAccess validates PyPI package access using HEAD request
|
||||
func (v *PyPIValidator) ValidateAccess(ctx context.Context, packageURL string, credentials string) (bool, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "HEAD", packageURL, nil)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Add credentials if provided
|
||||
if credentials != "" {
|
||||
req.Header.Set("Authorization", credentials)
|
||||
}
|
||||
|
||||
resp, err := v.client.Do(req)
|
||||
if err != nil {
|
||||
// Network error - allow cache fallback with warning
|
||||
log.Warn().Err(err).Str("url", packageURL).Msg("Validation request failed, allowing cache fallback")
|
||||
return true, fmt.Errorf("validation failed: %w (allowing cache fallback)", err)
|
||||
}
|
||||
defer resp.Body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
// Check status code
|
||||
switch resp.StatusCode {
|
||||
case 200, 304:
|
||||
// Access granted
|
||||
return true, nil
|
||||
case 401, 403, 404:
|
||||
// Access denied
|
||||
return false, fmt.Errorf("access denied: HTTP %d", resp.StatusCode)
|
||||
default:
|
||||
// Unexpected status - allow cache fallback with warning
|
||||
log.Warn().Int("status", resp.StatusCode).Str("url", packageURL).Msg("Unexpected validation status, allowing cache fallback")
|
||||
return true, fmt.Errorf("unexpected status %d (allowing cache fallback)", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
// GoValidator validates Go module credentials
|
||||
type GoValidator struct {
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// NewGoValidator creates a new Go module credential validator
|
||||
func NewGoValidator() *GoValidator {
|
||||
return &GoValidator{
|
||||
timeout: 10 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateAccess validates Go module access using git ls-remote
|
||||
func (v *GoValidator) ValidateAccess(ctx context.Context, modulePath string, credentials string) (bool, error) {
|
||||
// Create context with timeout
|
||||
ctx, cancel := context.WithTimeout(ctx, v.timeout)
|
||||
defer cancel()
|
||||
|
||||
// Determine repository type and validate accordingly
|
||||
if strings.HasPrefix(modulePath, "github.com/") {
|
||||
return v.validateGitHub(ctx, modulePath, credentials)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(modulePath, "gitlab.com/") {
|
||||
return v.validateGitLab(ctx, modulePath, credentials)
|
||||
}
|
||||
|
||||
// For other Git providers, use generic git validation
|
||||
return v.validateGit(ctx, modulePath, credentials)
|
||||
}
|
||||
|
||||
func (v *GoValidator) validateGitHub(ctx context.Context, modulePath, credentials string) (bool, error) {
|
||||
// Extract token from credentials
|
||||
token := strings.TrimPrefix(credentials, "Bearer ")
|
||||
token = strings.TrimPrefix(token, "Token ")
|
||||
|
||||
if token == "" || token == credentials {
|
||||
// No token provided or not in expected format
|
||||
return false, fmt.Errorf("no GitHub token provided")
|
||||
}
|
||||
|
||||
// Build git URL
|
||||
repoURL := fmt.Sprintf("https://%s.git", modulePath)
|
||||
|
||||
// Create temporary directory for .netrc
|
||||
tempDir, err := os.MkdirTemp("", "gohoarder-validate-*")
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create .netrc file with credentials
|
||||
netrcPath := filepath.Join(tempDir, ".netrc")
|
||||
netrcContent := fmt.Sprintf("machine github.com\nlogin oauth2\npassword %s\n", token)
|
||||
if err := os.WriteFile(netrcPath, []byte(netrcContent), 0600); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Run git ls-remote (lightweight, just checks access)
|
||||
cmd := exec.CommandContext(ctx, "git", "ls-remote", repoURL, "HEAD") // #nosec G204 -- git command with validated URL
|
||||
cmd.Env = append(os.Environ(),
|
||||
"HOME="+tempDir, // Use temp .netrc
|
||||
"GIT_TERMINAL_PROMPT=0", // Disable prompts
|
||||
)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
// Check error message
|
||||
errMsg := string(output)
|
||||
if strings.Contains(errMsg, "could not read Username") ||
|
||||
strings.Contains(errMsg, "Authentication failed") ||
|
||||
strings.Contains(errMsg, "fatal: repository") ||
|
||||
strings.Contains(errMsg, "not found") {
|
||||
// Access denied
|
||||
return false, fmt.Errorf("access denied: %s", strings.TrimSpace(errMsg))
|
||||
}
|
||||
|
||||
// Other error (network, etc.) - allow cache fallback
|
||||
log.Warn().Err(err).Str("module", modulePath).Msg("Git validation failed, allowing cache fallback")
|
||||
return true, fmt.Errorf("validation error (allowing cache): %w", err)
|
||||
}
|
||||
|
||||
// Success - repository accessible
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (v *GoValidator) validateGitLab(ctx context.Context, modulePath, credentials string) (bool, error) {
|
||||
// Extract token from credentials
|
||||
token := strings.TrimPrefix(credentials, "Bearer ")
|
||||
token = strings.TrimPrefix(token, "Token ")
|
||||
token = strings.TrimPrefix(token, "Private-Token ")
|
||||
|
||||
if token == "" || token == credentials {
|
||||
// No token provided
|
||||
return false, fmt.Errorf("no GitLab token provided")
|
||||
}
|
||||
|
||||
// Build git URL
|
||||
repoURL := fmt.Sprintf("https://%s.git", modulePath)
|
||||
|
||||
// Create temporary directory for .netrc
|
||||
tempDir, err := os.MkdirTemp("", "gohoarder-validate-*")
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create .netrc file with credentials
|
||||
netrcPath := filepath.Join(tempDir, ".netrc")
|
||||
netrcContent := fmt.Sprintf("machine gitlab.com\nlogin oauth2\npassword %s\n", token)
|
||||
if err := os.WriteFile(netrcPath, []byte(netrcContent), 0600); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Run git ls-remote
|
||||
cmd := exec.CommandContext(ctx, "git", "ls-remote", repoURL, "HEAD") // #nosec G204 -- git command with validated URL
|
||||
cmd.Env = append(os.Environ(),
|
||||
"HOME="+tempDir,
|
||||
"GIT_TERMINAL_PROMPT=0",
|
||||
)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
errMsg := string(output)
|
||||
if strings.Contains(errMsg, "could not read Username") ||
|
||||
strings.Contains(errMsg, "Authentication failed") ||
|
||||
strings.Contains(errMsg, "not found") {
|
||||
return false, fmt.Errorf("access denied: %s", strings.TrimSpace(errMsg))
|
||||
}
|
||||
|
||||
log.Warn().Err(err).Str("module", modulePath).Msg("Git validation failed, allowing cache fallback")
|
||||
return true, fmt.Errorf("validation error (allowing cache): %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (v *GoValidator) validateGit(ctx context.Context, modulePath, credentials string) (bool, error) {
|
||||
// Generic git validation for other providers
|
||||
// Similar to GitHub validation but with generic host detection
|
||||
repoURL := fmt.Sprintf("https://%s.git", modulePath)
|
||||
|
||||
cmd := exec.CommandContext(ctx, "git", "ls-remote", repoURL, "HEAD") // #nosec G204 -- git command with validated URL
|
||||
cmd.Env = append(os.Environ(), "GIT_TERMINAL_PROMPT=0")
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
errMsg := string(output)
|
||||
if strings.Contains(errMsg, "could not read Username") ||
|
||||
strings.Contains(errMsg, "Authentication failed") ||
|
||||
strings.Contains(errMsg, "not found") {
|
||||
return false, fmt.Errorf("access denied: %s", strings.TrimSpace(errMsg))
|
||||
}
|
||||
|
||||
log.Warn().Err(err).Str("module", modulePath).Msg("Git validation failed, allowing cache fallback")
|
||||
return true, fmt.Errorf("validation error (allowing cache): %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
Vendored
+572
@@ -0,0 +1,572 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/errors"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metrics"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/storage"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
// ScannerInterface defines the interface for security scanners
|
||||
// Defined here to avoid circular dependency with scanner package
|
||||
type ScannerInterface interface {
|
||||
ScanPackage(ctx context.Context, registry, packageName, version string, filePath string) error
|
||||
CheckVulnerabilities(ctx context.Context, registry, packageName, version string) (blocked bool, reason string, err error)
|
||||
}
|
||||
|
||||
// Manager coordinates caching operations between storage and metadata
|
||||
type Manager struct {
|
||||
storage storage.StorageBackend
|
||||
metadata metadata.MetadataStore
|
||||
scanner ScannerInterface
|
||||
config Config
|
||||
sf singleflight.Group
|
||||
mu sync.RWMutex
|
||||
evicting bool
|
||||
}
|
||||
|
||||
// Config holds cache manager configuration
|
||||
type Config struct {
|
||||
DefaultTTL time.Duration // Default TTL for cached packages
|
||||
CleanupInterval time.Duration // How often to run cleanup
|
||||
EvictionThreshold float64 // Trigger eviction when usage > threshold (0.0-1.0)
|
||||
MaxConcurrent int // Max concurrent upstream fetches
|
||||
}
|
||||
|
||||
// CacheEntry represents a cached package
|
||||
type CacheEntry struct {
|
||||
Package *metadata.Package
|
||||
Data io.ReadCloser
|
||||
FromCache bool
|
||||
UpstreamURL string
|
||||
CacheControl string
|
||||
}
|
||||
|
||||
// New creates a new cache manager
|
||||
func New(storage storage.StorageBackend, metadata metadata.MetadataStore, scanner ScannerInterface, config Config) (*Manager, error) {
|
||||
if storage == nil {
|
||||
return nil, errors.New(errors.ErrCodeInvalidConfig, "storage backend is required")
|
||||
}
|
||||
|
||||
if metadata == nil {
|
||||
return nil, errors.New(errors.ErrCodeInvalidConfig, "metadata store is required")
|
||||
}
|
||||
|
||||
// Scanner is optional - can be nil if security scanning is disabled
|
||||
if scanner != nil {
|
||||
log.Info().Msg("Cache manager initialized with security scanning enabled")
|
||||
}
|
||||
|
||||
if config.DefaultTTL == 0 {
|
||||
config.DefaultTTL = 7 * 24 * time.Hour // 7 days default
|
||||
}
|
||||
|
||||
if config.CleanupInterval == 0 {
|
||||
config.CleanupInterval = 1 * time.Hour
|
||||
}
|
||||
|
||||
if config.EvictionThreshold == 0 {
|
||||
config.EvictionThreshold = 0.9 // 90% full
|
||||
}
|
||||
|
||||
if config.MaxConcurrent == 0 {
|
||||
config.MaxConcurrent = 100
|
||||
}
|
||||
|
||||
manager := &Manager{
|
||||
storage: storage,
|
||||
metadata: metadata,
|
||||
scanner: scanner,
|
||||
config: config,
|
||||
}
|
||||
|
||||
// Start background cleanup worker
|
||||
go manager.cleanupWorker()
|
||||
|
||||
return manager, nil
|
||||
}
|
||||
|
||||
// Get retrieves a package from cache or upstream
|
||||
func (m *Manager) Get(ctx context.Context, registry, name, version string, fetchFunc func(context.Context) (io.ReadCloser, string, error)) (*CacheEntry, error) {
|
||||
// Use singleflight to deduplicate concurrent requests
|
||||
key := fmt.Sprintf("%s/%s/%s", registry, name, version)
|
||||
|
||||
result, err, _ := m.sf.Do(key, func() (interface{}, error) {
|
||||
return m.getOrFetch(ctx, registry, name, version, fetchFunc)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result.(*CacheEntry), nil
|
||||
}
|
||||
|
||||
// getOrFetch implements the actual get-or-fetch logic
|
||||
func (m *Manager) getOrFetch(ctx context.Context, registry, name, version string, fetchFunc func(context.Context) (io.ReadCloser, string, error)) (*CacheEntry, error) {
|
||||
// Check metadata first
|
||||
pkg, err := m.metadata.GetPackage(ctx, registry, name, version)
|
||||
if err == nil {
|
||||
// Package found in metadata, check if expired
|
||||
if pkg.ExpiresAt != nil && time.Now().After(*pkg.ExpiresAt) {
|
||||
log.Debug().Str("package", name).Str("version", version).Msg("Package expired, re-fetching")
|
||||
metrics.RecordCacheEviction("ttl")
|
||||
// Delete expired package
|
||||
_ = m.deletePackage(ctx, pkg) // #nosec G104 -- Async cleanup
|
||||
} else {
|
||||
// Try to get from storage
|
||||
data, err := m.storage.Get(ctx, pkg.StorageKey)
|
||||
if err == nil {
|
||||
// Cache hit!
|
||||
metrics.RecordCacheHit(registry)
|
||||
_ = m.metadata.UpdateDownloadCount(ctx, registry, name, version) // #nosec G104 -- Async update, error logged
|
||||
|
||||
// Check for vulnerabilities if scanner is enabled
|
||||
if m.scanner != nil {
|
||||
blocked, reason, err := m.scanner.CheckVulnerabilities(ctx, registry, name, version)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("package", name).Msg("Failed to check vulnerabilities")
|
||||
}
|
||||
if blocked {
|
||||
metrics.RecordCacheHit(registry) // Record as blocked
|
||||
_ = data.Close() // #nosec G104 // Close the data reader
|
||||
return nil, errors.New(errors.ErrCodeSecurityViolation, reason)
|
||||
}
|
||||
}
|
||||
|
||||
return &CacheEntry{
|
||||
Package: pkg,
|
||||
Data: data,
|
||||
FromCache: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Storage miss but metadata exists - inconsistency, clean up
|
||||
log.Warn().Str("package", name).Str("version", version).Msg("Metadata exists but storage missing")
|
||||
_ = m.metadata.DeletePackage(ctx, registry, name, version) // #nosec G104 -- Cleanup, error logged
|
||||
}
|
||||
}
|
||||
|
||||
// Cache miss - fetch from upstream
|
||||
metrics.RecordCacheMiss(registry)
|
||||
|
||||
if fetchFunc == nil {
|
||||
return nil, errors.NotFound(fmt.Sprintf("package not found and no fetch function provided: %s/%s@%s", registry, name, version))
|
||||
}
|
||||
|
||||
log.Debug().Str("package", name).Str("version", version).Msg("Fetching from upstream")
|
||||
|
||||
// Fetch from upstream
|
||||
data, upstreamURL, err := fetchFunc(ctx)
|
||||
if err != nil {
|
||||
metrics.RecordUpstreamRequest(registry, "error")
|
||||
return nil, errors.Wrap(err, errors.ErrCodeUpstreamFailure, "failed to fetch from upstream")
|
||||
}
|
||||
defer data.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
metrics.RecordUpstreamRequest(registry, "success")
|
||||
|
||||
// Store in cache (this will also trigger background scan)
|
||||
storedPkg, err := m.store(ctx, registry, name, version, data, upstreamURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Wait briefly for initial scan to complete if scanner is enabled
|
||||
// This prevents serving vulnerable packages on first request
|
||||
if m.scanner != nil {
|
||||
// Wait up to 30 seconds for scan to complete
|
||||
scanCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-scanCtx.Done():
|
||||
// Timeout or context cancelled - proceed anyway
|
||||
// Package is cached, will be blocked on next request if vulnerable
|
||||
log.Warn().
|
||||
Str("package", name).
|
||||
Str("version", version).
|
||||
Msg("Scan timeout - allowing first download, will block on subsequent requests if vulnerable")
|
||||
goto servePkg
|
||||
|
||||
case <-ticker.C:
|
||||
// First check if scan has completed by checking the SecurityScanned flag
|
||||
// This prevents race condition where CheckVulnerabilities() returns "clean"
|
||||
// before all scanners have finished
|
||||
pkg, err := m.metadata.GetPackage(scanCtx, registry, name, version)
|
||||
if err != nil {
|
||||
// Failed to get package metadata - continue waiting
|
||||
log.Debug().
|
||||
Str("package", name).
|
||||
Str("version", version).
|
||||
Err(err).
|
||||
Msg("Failed to get package metadata, waiting...")
|
||||
continue
|
||||
}
|
||||
|
||||
if !pkg.SecurityScanned {
|
||||
// Scan still in progress - continue waiting
|
||||
log.Debug().
|
||||
Str("package", name).
|
||||
Str("version", version).
|
||||
Msg("Scan in progress, waiting...")
|
||||
continue
|
||||
}
|
||||
|
||||
// Scan completed - now check if package should be blocked
|
||||
blocked, reason, err := m.scanner.CheckVulnerabilities(scanCtx, registry, name, version)
|
||||
if err != nil {
|
||||
// Unexpected error after scan complete - log and continue waiting
|
||||
log.Warn().
|
||||
Str("package", name).
|
||||
Str("version", version).
|
||||
Err(err).
|
||||
Msg("Error checking vulnerabilities, waiting...")
|
||||
continue
|
||||
}
|
||||
|
||||
// Scan completed - check if blocked
|
||||
if blocked {
|
||||
log.Info().
|
||||
Str("package", name).
|
||||
Str("version", version).
|
||||
Str("reason", reason).
|
||||
Msg("Package cached but blocked due to vulnerabilities")
|
||||
return nil, errors.New(errors.ErrCodeSecurityViolation, reason)
|
||||
}
|
||||
|
||||
// Package is clean - proceed to serve
|
||||
log.Info().
|
||||
Str("package", name).
|
||||
Str("version", version).
|
||||
Msg("Scan completed, package is clean")
|
||||
goto servePkg
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
servePkg:
|
||||
// Re-open from storage for consistency
|
||||
storedData, err := m.storage.Get(ctx, storedPkg.StorageKey)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to retrieve just-stored package")
|
||||
}
|
||||
|
||||
return &CacheEntry{
|
||||
Package: storedPkg,
|
||||
Data: storedData,
|
||||
FromCache: false,
|
||||
UpstreamURL: upstreamURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// store stores a package in cache
|
||||
func (m *Manager) store(ctx context.Context, registry, name, version string, data io.ReadCloser, upstreamURL string) (*metadata.Package, error) {
|
||||
// Generate storage key
|
||||
storageKey := m.generateStorageKey(registry, name, version)
|
||||
|
||||
// Calculate checksums while storing
|
||||
// We need to read the data, calculate checksums, and store it
|
||||
// This requires buffering the data
|
||||
var buf []byte
|
||||
var err error
|
||||
|
||||
// Read all data
|
||||
buf, err = io.ReadAll(data)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, errors.ErrCodeUpstreamFailure, "failed to read upstream data")
|
||||
}
|
||||
|
||||
// Calculate checksums
|
||||
h := sha256.New()
|
||||
h.Write(buf)
|
||||
checksumSHA256 := fmt.Sprintf("%x", h.Sum(nil))
|
||||
|
||||
size := int64(len(buf))
|
||||
|
||||
// Check quota before storing
|
||||
quota, err := m.storage.GetQuota(ctx)
|
||||
if err == nil && quota.Limit > 0 {
|
||||
if quota.Used+size > quota.Limit {
|
||||
// Trigger eviction
|
||||
if err := m.evict(ctx, size); err != nil {
|
||||
return nil, errors.QuotaExceeded(quota.Limit)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store in storage backend
|
||||
opts := &storage.PutOptions{
|
||||
ChecksumSHA256: checksumSHA256,
|
||||
}
|
||||
|
||||
err = m.storage.Put(ctx, storageKey, io.NopCloser(bytes.NewReader(buf)), opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create metadata entry
|
||||
now := time.Now()
|
||||
expiresAt := now.Add(m.config.DefaultTTL)
|
||||
|
||||
pkg := &metadata.Package{
|
||||
ID: uuid.New().String(),
|
||||
Registry: registry,
|
||||
Name: name,
|
||||
Version: version,
|
||||
StorageKey: storageKey,
|
||||
Size: size,
|
||||
ChecksumSHA256: checksumSHA256,
|
||||
UpstreamURL: upstreamURL,
|
||||
CachedAt: now,
|
||||
LastAccessed: now,
|
||||
ExpiresAt: &expiresAt,
|
||||
DownloadCount: 0,
|
||||
Metadata: make(map[string]string),
|
||||
}
|
||||
|
||||
// Save metadata
|
||||
if err := m.metadata.SavePackage(ctx, pkg); err != nil {
|
||||
// Clean up storage if metadata save fails
|
||||
_ = m.storage.Delete(ctx, storageKey) // #nosec G104 -- Cleanup, error logged
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Scan package if scanner is enabled (run in background to not block cache operations)
|
||||
if m.scanner != nil {
|
||||
go func() {
|
||||
scanCtx := context.Background()
|
||||
var filePath string
|
||||
var cleanupFunc func()
|
||||
|
||||
// Check if storage backend supports local paths
|
||||
if localProvider, ok := m.storage.(interface {
|
||||
GetLocalPath(ctx context.Context, key string) (string, error)
|
||||
}); ok {
|
||||
// Use direct file path from storage (avoid double download)
|
||||
path, err := localProvider.GetLocalPath(scanCtx, storageKey)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("package", name).Msg("Failed to get local path for scanning")
|
||||
return
|
||||
}
|
||||
filePath = path
|
||||
cleanupFunc = func() {} // No cleanup needed for direct path
|
||||
log.Debug().Str("package", name).Str("path", filePath).Msg("Scanning package from storage path")
|
||||
} else {
|
||||
// Fallback: Create temp file for remote storage (S3, SMB, etc.)
|
||||
tempFilePath := filepath.Join(os.TempDir(), storageKey)
|
||||
|
||||
// Create parent directories if they don't exist
|
||||
if err := os.MkdirAll(filepath.Dir(tempFilePath), 0750); err != nil {
|
||||
log.Error().Err(err).Str("package", name).Msg("Failed to create temp directory for scanning")
|
||||
return
|
||||
}
|
||||
|
||||
tempFile, err := os.Create(tempFilePath) // #nosec G304 -- Temp file path is constructed from validated package name
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("package", name).Msg("Failed to create temp file for scanning")
|
||||
return
|
||||
}
|
||||
|
||||
// Write package data to temp file
|
||||
if _, err := tempFile.Write(buf); err != nil {
|
||||
tempFile.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
_ = os.Remove(tempFilePath) // #nosec G104 -- Cleanup, error not critical
|
||||
log.Error().Err(err).Str("package", name).Msg("Failed to write temp file for scanning")
|
||||
return
|
||||
}
|
||||
tempFile.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
filePath = tempFilePath
|
||||
cleanupFunc = func() { _ = os.Remove(tempFilePath) } // #nosec G104 -- Cleanup
|
||||
log.Debug().Str("package", name).Str("path", filePath).Msg("Scanning package from temp file")
|
||||
}
|
||||
|
||||
defer cleanupFunc()
|
||||
|
||||
// Scan package
|
||||
if err := m.scanner.ScanPackage(scanCtx, registry, name, version, filePath); err != nil {
|
||||
log.Error().Err(err).Str("package", name).Msg("Failed to scan package")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return pkg, nil
|
||||
}
|
||||
|
||||
// Delete removes a package from cache
|
||||
func (m *Manager) Delete(ctx context.Context, registry, name, version string) error {
|
||||
pkg, err := m.metadata.GetPackage(ctx, registry, name, version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return m.deletePackage(ctx, pkg)
|
||||
}
|
||||
|
||||
// deletePackage deletes a package from both storage and metadata
|
||||
func (m *Manager) deletePackage(ctx context.Context, pkg *metadata.Package) error {
|
||||
// Delete from storage
|
||||
if err := m.storage.Delete(ctx, pkg.StorageKey); err != nil {
|
||||
log.Warn().Err(err).Str("key", pkg.StorageKey).Msg("Failed to delete from storage")
|
||||
}
|
||||
|
||||
// Delete from metadata
|
||||
return m.metadata.DeletePackage(ctx, pkg.Registry, pkg.Name, pkg.Version)
|
||||
}
|
||||
|
||||
// evict implements LRU eviction
|
||||
func (m *Manager) evict(ctx context.Context, needed int64) error {
|
||||
m.mu.Lock()
|
||||
if m.evicting {
|
||||
m.mu.Unlock()
|
||||
return errors.New(errors.ErrCodeStorageFailure, "eviction already in progress")
|
||||
}
|
||||
m.evicting = true
|
||||
m.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
m.mu.Lock()
|
||||
m.evicting = false
|
||||
m.mu.Unlock()
|
||||
}()
|
||||
|
||||
log.Info().Int64("needed", needed).Msg("Starting LRU eviction")
|
||||
|
||||
// List packages sorted by last accessed (oldest first)
|
||||
opts := &metadata.ListOptions{
|
||||
SortBy: "last_accessed",
|
||||
SortDesc: false,
|
||||
Limit: 100,
|
||||
}
|
||||
|
||||
var freed int64
|
||||
for freed < needed {
|
||||
packages, err := m.metadata.ListPackages(ctx, opts)
|
||||
if err != nil || len(packages) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
for _, pkg := range packages {
|
||||
if err := m.deletePackage(ctx, pkg); err != nil {
|
||||
log.Warn().Err(err).Str("package", pkg.Name).Msg("Failed to evict package")
|
||||
continue
|
||||
}
|
||||
|
||||
freed += pkg.Size
|
||||
metrics.RecordCacheEviction("lru")
|
||||
|
||||
if freed >= needed {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(packages) < opts.Limit {
|
||||
break // No more packages
|
||||
}
|
||||
}
|
||||
|
||||
log.Info().Int64("freed", freed).Msg("Eviction completed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupWorker runs periodic cleanup of expired packages
|
||||
func (m *Manager) cleanupWorker() {
|
||||
ticker := time.NewTicker(m.config.CleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
ctx := context.Background()
|
||||
m.cleanup(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup removes expired packages
|
||||
func (m *Manager) cleanup(ctx context.Context) {
|
||||
log.Debug().Msg("Starting cleanup worker")
|
||||
|
||||
// List all packages
|
||||
packages, err := m.metadata.ListPackages(ctx, &metadata.ListOptions{})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to list packages for cleanup")
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
var cleaned int
|
||||
|
||||
for _, pkg := range packages {
|
||||
if pkg.ExpiresAt != nil && now.After(*pkg.ExpiresAt) {
|
||||
if err := m.deletePackage(ctx, pkg); err != nil {
|
||||
log.Warn().Err(err).Str("package", pkg.Name).Msg("Failed to clean up expired package")
|
||||
continue
|
||||
}
|
||||
cleaned++
|
||||
}
|
||||
}
|
||||
|
||||
if cleaned > 0 {
|
||||
log.Info().Int("count", cleaned).Msg("Cleanup completed")
|
||||
}
|
||||
}
|
||||
|
||||
// generateStorageKey generates a storage key for a package
|
||||
func (m *Manager) generateStorageKey(registry, name, version string) string {
|
||||
return fmt.Sprintf("%s/%s/%s", registry, name, version)
|
||||
}
|
||||
|
||||
// GetStats returns cache statistics
|
||||
func (m *Manager) GetStats(ctx context.Context, registry string) (*metadata.Stats, error) {
|
||||
return m.metadata.GetStats(ctx, registry)
|
||||
}
|
||||
|
||||
// Health checks cache manager health
|
||||
func (m *Manager) Health(ctx context.Context) error {
|
||||
// Check storage health
|
||||
if err := m.storage.Health(ctx); err != nil {
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "storage health check failed")
|
||||
}
|
||||
|
||||
// Check metadata health
|
||||
if err := m.metadata.Health(ctx); err != nil {
|
||||
return errors.Wrap(err, errors.ErrCodeDatabaseFailure, "metadata health check failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the cache manager
|
||||
func (m *Manager) Close() error {
|
||||
var err error
|
||||
|
||||
if closeErr := m.storage.Close(); closeErr != nil {
|
||||
err = closeErr
|
||||
}
|
||||
|
||||
if closeErr := m.metadata.Close(); closeErr != nil {
|
||||
if err != nil {
|
||||
err = fmt.Errorf("%w; %w", err, closeErr)
|
||||
} else {
|
||||
err = closeErr
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
Vendored
+980
@@ -0,0 +1,980 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/storage"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// MockStorageBackend is a mock for storage.StorageBackend
|
||||
type MockStorageBackend struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockStorageBackend) Get(ctx context.Context, key string) (io.ReadCloser, error) {
|
||||
args := m.Called(ctx, key)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(io.ReadCloser), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockStorageBackend) Put(ctx context.Context, key string, data io.Reader, opts *storage.PutOptions) error {
|
||||
args := m.Called(ctx, key, data, opts)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockStorageBackend) Delete(ctx context.Context, key string) error {
|
||||
args := m.Called(ctx, key)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockStorageBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
args := m.Called(ctx, key)
|
||||
return args.Bool(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockStorageBackend) List(ctx context.Context, prefix string, opts *storage.ListOptions) ([]storage.StorageObject, error) {
|
||||
args := m.Called(ctx, prefix, opts)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).([]storage.StorageObject), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockStorageBackend) Stat(ctx context.Context, key string) (*storage.StorageInfo, error) {
|
||||
args := m.Called(ctx, key)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*storage.StorageInfo), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockStorageBackend) GetQuota(ctx context.Context) (*storage.QuotaInfo, error) {
|
||||
args := m.Called(ctx)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*storage.QuotaInfo), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockStorageBackend) Health(ctx context.Context) error {
|
||||
args := m.Called(ctx)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockStorageBackend) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
// MockMetadataStore is a mock for metadata.MetadataStore
|
||||
type MockMetadataStore struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockMetadataStore) SavePackage(ctx context.Context, pkg *metadata.Package) error {
|
||||
args := m.Called(ctx, pkg)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockMetadataStore) GetPackage(ctx context.Context, registry, name, version string) (*metadata.Package, error) {
|
||||
args := m.Called(ctx, registry, name, version)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*metadata.Package), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockMetadataStore) DeletePackage(ctx context.Context, registry, name, version string) error {
|
||||
args := m.Called(ctx, registry, name, version)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockMetadataStore) ListPackages(ctx context.Context, opts *metadata.ListOptions) ([]*metadata.Package, error) {
|
||||
args := m.Called(ctx, opts)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).([]*metadata.Package), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockMetadataStore) UpdateDownloadCount(ctx context.Context, registry, name, version string) error {
|
||||
args := m.Called(ctx, registry, name, version)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockMetadataStore) GetStats(ctx context.Context, registry string) (*metadata.Stats, error) {
|
||||
args := m.Called(ctx, registry)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*metadata.Stats), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockMetadataStore) SaveScanResult(ctx context.Context, result *metadata.ScanResult) error {
|
||||
args := m.Called(ctx, result)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockMetadataStore) GetScanResult(ctx context.Context, registry, name, version string) (*metadata.ScanResult, error) {
|
||||
args := m.Called(ctx, registry, name, version)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*metadata.ScanResult), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockMetadataStore) Count(ctx context.Context) (int, error) {
|
||||
args := m.Called(ctx)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockMetadataStore) Health(ctx context.Context) error {
|
||||
args := m.Called(ctx)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockMetadataStore) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockMetadataStore) SaveCVEBypass(ctx context.Context, bypass *metadata.CVEBypass) error {
|
||||
args := m.Called(ctx, bypass)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockMetadataStore) GetActiveCVEBypasses(ctx context.Context) ([]*metadata.CVEBypass, error) {
|
||||
args := m.Called(ctx)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).([]*metadata.CVEBypass), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockMetadataStore) ListCVEBypasses(ctx context.Context, opts *metadata.BypassListOptions) ([]*metadata.CVEBypass, error) {
|
||||
args := m.Called(ctx, opts)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).([]*metadata.CVEBypass), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockMetadataStore) DeleteCVEBypass(ctx context.Context, id string) error {
|
||||
args := m.Called(ctx, id)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockMetadataStore) CleanupExpiredBypasses(ctx context.Context) (int, error) {
|
||||
args := m.Called(ctx)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockMetadataStore) GetTimeSeriesStats(ctx context.Context, period string, registry string) (*metadata.TimeSeriesStats, error) {
|
||||
args := m.Called(ctx, period, registry)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*metadata.TimeSeriesStats), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockMetadataStore) AggregateDownloadData(ctx context.Context) error {
|
||||
args := m.Called(ctx)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
// TestNew tests cache manager creation
|
||||
func TestNew(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
storage storage.StorageBackend
|
||||
metadata metadata.MetadataStore
|
||||
config Config
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
// GOOD: Valid configuration
|
||||
{
|
||||
name: "valid config with defaults",
|
||||
storage: &MockStorageBackend{},
|
||||
metadata: &MockMetadataStore{},
|
||||
config: Config{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid config with custom settings",
|
||||
storage: &MockStorageBackend{},
|
||||
metadata: &MockMetadataStore{},
|
||||
config: Config{
|
||||
DefaultTTL: 24 * time.Hour,
|
||||
CleanupInterval: 30 * time.Minute,
|
||||
EvictionThreshold: 0.8,
|
||||
MaxConcurrent: 50,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
// WRONG: Missing required components
|
||||
{
|
||||
name: "nil storage",
|
||||
storage: nil,
|
||||
metadata: &MockMetadataStore{},
|
||||
config: Config{},
|
||||
wantErr: true,
|
||||
errContains: "storage backend is required",
|
||||
},
|
||||
{
|
||||
name: "nil metadata",
|
||||
storage: &MockStorageBackend{},
|
||||
metadata: nil,
|
||||
config: Config{},
|
||||
wantErr: true,
|
||||
errContains: "metadata store is required",
|
||||
},
|
||||
// EDGE: Both nil
|
||||
{
|
||||
name: "both nil",
|
||||
storage: nil,
|
||||
metadata: nil,
|
||||
config: Config{},
|
||||
wantErr: true,
|
||||
errContains: "storage backend is required",
|
||||
},
|
||||
// EDGE: Zero values get defaults
|
||||
{
|
||||
name: "zero config gets defaults",
|
||||
storage: &MockStorageBackend{},
|
||||
metadata: &MockMetadataStore{},
|
||||
config: Config{},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager, err := New(tt.storage, tt.metadata, nil, tt.config)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
assert.Nil(t, manager)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, manager)
|
||||
|
||||
// Verify defaults were set
|
||||
if tt.config.DefaultTTL == 0 {
|
||||
assert.Equal(t, 7*24*time.Hour, manager.config.DefaultTTL)
|
||||
}
|
||||
if tt.config.CleanupInterval == 0 {
|
||||
assert.Equal(t, 1*time.Hour, manager.config.CleanupInterval)
|
||||
}
|
||||
if tt.config.EvictionThreshold == 0 {
|
||||
assert.Equal(t, 0.9, manager.config.EvictionThreshold)
|
||||
}
|
||||
if tt.config.MaxConcurrent == 0 {
|
||||
assert.Equal(t, 100, manager.config.MaxConcurrent)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGet tests cache retrieval with various scenarios
|
||||
func TestGet(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
registry string
|
||||
packageName string
|
||||
version string
|
||||
setupMock func(*MockStorageBackend, *MockMetadataStore)
|
||||
fetchFunc func(context.Context) (io.ReadCloser, string, error)
|
||||
wantFromCache bool
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
// GOOD: Cache hit
|
||||
{
|
||||
name: "cache hit - package exists and valid",
|
||||
registry: "npm",
|
||||
packageName: "react",
|
||||
version: "18.2.0",
|
||||
setupMock: func(s *MockStorageBackend, m *MockMetadataStore) {
|
||||
now := time.Now()
|
||||
expiresAt := now.Add(24 * time.Hour)
|
||||
pkg := &metadata.Package{
|
||||
ID: "test-id",
|
||||
Registry: "npm",
|
||||
Name: "react",
|
||||
Version: "18.2.0",
|
||||
StorageKey: "npm/react/18.2.0",
|
||||
CachedAt: now,
|
||||
LastAccessed: now,
|
||||
ExpiresAt: &expiresAt,
|
||||
}
|
||||
m.On("GetPackage", mock.Anything, "npm", "react", "18.2.0").Return(pkg, nil)
|
||||
s.On("Get", mock.Anything, "npm/react/18.2.0").Return(io.NopCloser(strings.NewReader("cached data")), nil)
|
||||
m.On("UpdateDownloadCount", mock.Anything, "npm", "react", "18.2.0").Return(nil)
|
||||
},
|
||||
wantFromCache: true,
|
||||
wantErr: false,
|
||||
},
|
||||
// GOOD: Cache miss - fetch from upstream
|
||||
{
|
||||
name: "cache miss - fetch from upstream",
|
||||
registry: "npm",
|
||||
packageName: "lodash",
|
||||
version: "4.17.21",
|
||||
setupMock: func(s *MockStorageBackend, m *MockMetadataStore) {
|
||||
m.On("GetPackage", mock.Anything, "npm", "lodash", "4.17.21").Return(nil, errors.New("not found"))
|
||||
s.On("GetQuota", mock.Anything).Return(&storage.QuotaInfo{Used: 100, Available: 900, Limit: 1000}, nil)
|
||||
s.On("Put", mock.Anything, "npm/lodash/4.17.21", mock.Anything, mock.Anything).Return(nil)
|
||||
m.On("SavePackage", mock.Anything, mock.Anything).Return(nil)
|
||||
s.On("Get", mock.Anything, "npm/lodash/4.17.21").Return(io.NopCloser(strings.NewReader("upstream data")), nil)
|
||||
},
|
||||
fetchFunc: func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
return io.NopCloser(strings.NewReader("upstream data")), "https://registry.npmjs.org/lodash", nil
|
||||
},
|
||||
wantFromCache: false,
|
||||
wantErr: false,
|
||||
},
|
||||
// WRONG: Expired package
|
||||
{
|
||||
name: "expired package - re-fetch",
|
||||
registry: "npm",
|
||||
packageName: "expired-pkg",
|
||||
version: "1.0.0",
|
||||
setupMock: func(s *MockStorageBackend, m *MockMetadataStore) {
|
||||
now := time.Now()
|
||||
expiresAt := now.Add(-1 * time.Hour) // Expired 1 hour ago
|
||||
pkg := &metadata.Package{
|
||||
ID: "test-id",
|
||||
Registry: "npm",
|
||||
Name: "expired-pkg",
|
||||
Version: "1.0.0",
|
||||
StorageKey: "npm/expired-pkg/1.0.0",
|
||||
ExpiresAt: &expiresAt,
|
||||
}
|
||||
m.On("GetPackage", mock.Anything, "npm", "expired-pkg", "1.0.0").Return(pkg, nil)
|
||||
m.On("DeletePackage", mock.Anything, "npm", "expired-pkg", "1.0.0").Return(nil)
|
||||
s.On("Delete", mock.Anything, "npm/expired-pkg/1.0.0").Return(nil)
|
||||
s.On("GetQuota", mock.Anything).Return(&storage.QuotaInfo{Used: 100, Available: 900, Limit: 1000}, nil)
|
||||
s.On("Put", mock.Anything, "npm/expired-pkg/1.0.0", mock.Anything, mock.Anything).Return(nil)
|
||||
m.On("SavePackage", mock.Anything, mock.Anything).Return(nil)
|
||||
s.On("Get", mock.Anything, "npm/expired-pkg/1.0.0").Return(io.NopCloser(strings.NewReader("refreshed data")), nil)
|
||||
},
|
||||
fetchFunc: func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
return io.NopCloser(strings.NewReader("refreshed data")), "https://registry.npmjs.org/expired-pkg", nil
|
||||
},
|
||||
wantFromCache: false,
|
||||
wantErr: false,
|
||||
},
|
||||
// BAD: Fetch function is nil and package not cached
|
||||
{
|
||||
name: "nil fetch function and not cached",
|
||||
registry: "npm",
|
||||
packageName: "missing",
|
||||
version: "1.0.0",
|
||||
setupMock: func(s *MockStorageBackend, m *MockMetadataStore) {
|
||||
m.On("GetPackage", mock.Anything, "npm", "missing", "1.0.0").Return(nil, errors.New("not found"))
|
||||
},
|
||||
fetchFunc: nil,
|
||||
wantErr: true,
|
||||
errContains: "package not found and no fetch function provided",
|
||||
},
|
||||
// BAD: Upstream fetch fails
|
||||
{
|
||||
name: "upstream fetch error",
|
||||
registry: "npm",
|
||||
packageName: "fail-pkg",
|
||||
version: "1.0.0",
|
||||
setupMock: func(s *MockStorageBackend, m *MockMetadataStore) {
|
||||
m.On("GetPackage", mock.Anything, "npm", "fail-pkg", "1.0.0").Return(nil, errors.New("not found"))
|
||||
},
|
||||
fetchFunc: func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
return nil, "", errors.New("upstream error")
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "failed to fetch from upstream",
|
||||
},
|
||||
// EDGE: Metadata exists but storage missing
|
||||
{
|
||||
name: "metadata exists but storage missing - inconsistency",
|
||||
registry: "npm",
|
||||
packageName: "inconsistent",
|
||||
version: "1.0.0",
|
||||
setupMock: func(s *MockStorageBackend, m *MockMetadataStore) {
|
||||
now := time.Now()
|
||||
expiresAt := now.Add(24 * time.Hour)
|
||||
pkg := &metadata.Package{
|
||||
ID: "test-id",
|
||||
Registry: "npm",
|
||||
Name: "inconsistent",
|
||||
Version: "1.0.0",
|
||||
StorageKey: "npm/inconsistent/1.0.0",
|
||||
ExpiresAt: &expiresAt,
|
||||
}
|
||||
m.On("GetPackage", mock.Anything, "npm", "inconsistent", "1.0.0").Return(pkg, nil)
|
||||
// First Get fails (storage missing)
|
||||
s.On("Get", mock.Anything, "npm/inconsistent/1.0.0").Return(nil, errors.New("not found")).Once()
|
||||
m.On("DeletePackage", mock.Anything, "npm", "inconsistent", "1.0.0").Return(nil)
|
||||
s.On("GetQuota", mock.Anything).Return(&storage.QuotaInfo{Used: 100, Available: 900, Limit: 1000}, nil)
|
||||
s.On("Put", mock.Anything, "npm/inconsistent/1.0.0", mock.Anything, mock.Anything).Return(nil)
|
||||
m.On("SavePackage", mock.Anything, mock.Anything).Return(nil)
|
||||
// Second Get succeeds (after re-storing)
|
||||
s.On("Get", mock.Anything, "npm/inconsistent/1.0.0").Return(io.NopCloser(strings.NewReader("recovered data")), nil).Once()
|
||||
},
|
||||
fetchFunc: func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
return io.NopCloser(strings.NewReader("recovered data")), "https://registry.npmjs.org/inconsistent", nil
|
||||
},
|
||||
wantFromCache: false,
|
||||
wantErr: false,
|
||||
},
|
||||
// EDGE: Storage save fails
|
||||
{
|
||||
name: "storage save fails",
|
||||
registry: "npm",
|
||||
packageName: "save-fail",
|
||||
version: "1.0.0",
|
||||
setupMock: func(s *MockStorageBackend, m *MockMetadataStore) {
|
||||
m.On("GetPackage", mock.Anything, "npm", "save-fail", "1.0.0").Return(nil, errors.New("not found"))
|
||||
s.On("GetQuota", mock.Anything).Return(&storage.QuotaInfo{Used: 100, Available: 900, Limit: 1000}, nil)
|
||||
s.On("Put", mock.Anything, "npm/save-fail/1.0.0", mock.Anything, mock.Anything).Return(errors.New("storage error"))
|
||||
},
|
||||
fetchFunc: func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
return io.NopCloser(strings.NewReader("data")), "https://registry.npmjs.org/save-fail", nil
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "storage error",
|
||||
},
|
||||
// EDGE: Metadata save fails (should cleanup storage)
|
||||
{
|
||||
name: "metadata save fails - storage cleanup",
|
||||
registry: "npm",
|
||||
packageName: "meta-fail",
|
||||
version: "1.0.0",
|
||||
setupMock: func(s *MockStorageBackend, m *MockMetadataStore) {
|
||||
m.On("GetPackage", mock.Anything, "npm", "meta-fail", "1.0.0").Return(nil, errors.New("not found"))
|
||||
s.On("GetQuota", mock.Anything).Return(&storage.QuotaInfo{Used: 100, Available: 900, Limit: 1000}, nil)
|
||||
s.On("Put", mock.Anything, "npm/meta-fail/1.0.0", mock.Anything, mock.Anything).Return(nil)
|
||||
m.On("SavePackage", mock.Anything, mock.Anything).Return(errors.New("metadata error"))
|
||||
s.On("Delete", mock.Anything, "npm/meta-fail/1.0.0").Return(nil)
|
||||
},
|
||||
fetchFunc: func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
return io.NopCloser(strings.NewReader("data")), "https://registry.npmjs.org/meta-fail", nil
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "metadata error",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockStorage := &MockStorageBackend{}
|
||||
mockMetadata := &MockMetadataStore{}
|
||||
|
||||
if tt.setupMock != nil {
|
||||
tt.setupMock(mockStorage, mockMetadata)
|
||||
}
|
||||
|
||||
manager, err := New(mockStorage, mockMetadata, nil, Config{
|
||||
DefaultTTL: 24 * time.Hour,
|
||||
CleanupInterval: 1 * time.Hour,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
entry, err := manager.Get(ctx, tt.registry, tt.packageName, tt.version, tt.fetchFunc)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
assert.Nil(t, entry)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, entry)
|
||||
assert.Equal(t, tt.wantFromCache, entry.FromCache)
|
||||
assert.NotNil(t, entry.Data)
|
||||
// Read and verify data exists
|
||||
data, _ := io.ReadAll(entry.Data)
|
||||
assert.NotEmpty(t, data)
|
||||
}
|
||||
|
||||
mockStorage.AssertExpectations(t)
|
||||
mockMetadata.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDelete tests package deletion
|
||||
func TestDelete(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
registry string
|
||||
packageName string
|
||||
version string
|
||||
setupMock func(*MockStorageBackend, *MockMetadataStore)
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
// GOOD: Successful deletion
|
||||
{
|
||||
name: "successful deletion",
|
||||
registry: "npm",
|
||||
packageName: "react",
|
||||
version: "18.2.0",
|
||||
setupMock: func(s *MockStorageBackend, m *MockMetadataStore) {
|
||||
pkg := &metadata.Package{
|
||||
ID: "test-id",
|
||||
Registry: "npm",
|
||||
Name: "react",
|
||||
Version: "18.2.0",
|
||||
StorageKey: "npm/react/18.2.0",
|
||||
}
|
||||
m.On("GetPackage", mock.Anything, "npm", "react", "18.2.0").Return(pkg, nil)
|
||||
s.On("Delete", mock.Anything, "npm/react/18.2.0").Return(nil)
|
||||
m.On("DeletePackage", mock.Anything, "npm", "react", "18.2.0").Return(nil)
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
// WRONG: Package not found
|
||||
{
|
||||
name: "package not found",
|
||||
registry: "npm",
|
||||
packageName: "missing",
|
||||
version: "1.0.0",
|
||||
setupMock: func(s *MockStorageBackend, m *MockMetadataStore) {
|
||||
m.On("GetPackage", mock.Anything, "npm", "missing", "1.0.0").Return(nil, errors.New("not found"))
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "not found",
|
||||
},
|
||||
// EDGE: Storage delete fails but metadata succeeds
|
||||
{
|
||||
name: "storage delete fails",
|
||||
registry: "npm",
|
||||
packageName: "storage-fail",
|
||||
version: "1.0.0",
|
||||
setupMock: func(s *MockStorageBackend, m *MockMetadataStore) {
|
||||
pkg := &metadata.Package{
|
||||
ID: "test-id",
|
||||
Registry: "npm",
|
||||
Name: "storage-fail",
|
||||
Version: "1.0.0",
|
||||
StorageKey: "npm/storage-fail/1.0.0",
|
||||
}
|
||||
m.On("GetPackage", mock.Anything, "npm", "storage-fail", "1.0.0").Return(pkg, nil)
|
||||
s.On("Delete", mock.Anything, "npm/storage-fail/1.0.0").Return(errors.New("storage error"))
|
||||
m.On("DeletePackage", mock.Anything, "npm", "storage-fail", "1.0.0").Return(nil)
|
||||
},
|
||||
wantErr: false, // Metadata delete still succeeds
|
||||
},
|
||||
// EDGE: Metadata delete fails
|
||||
{
|
||||
name: "metadata delete fails",
|
||||
registry: "npm",
|
||||
packageName: "meta-fail",
|
||||
version: "1.0.0",
|
||||
setupMock: func(s *MockStorageBackend, m *MockMetadataStore) {
|
||||
pkg := &metadata.Package{
|
||||
ID: "test-id",
|
||||
Registry: "npm",
|
||||
Name: "meta-fail",
|
||||
Version: "1.0.0",
|
||||
StorageKey: "npm/meta-fail/1.0.0",
|
||||
}
|
||||
m.On("GetPackage", mock.Anything, "npm", "meta-fail", "1.0.0").Return(pkg, nil)
|
||||
s.On("Delete", mock.Anything, "npm/meta-fail/1.0.0").Return(nil)
|
||||
m.On("DeletePackage", mock.Anything, "npm", "meta-fail", "1.0.0").Return(errors.New("metadata error"))
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "metadata error",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockStorage := &MockStorageBackend{}
|
||||
mockMetadata := &MockMetadataStore{}
|
||||
|
||||
if tt.setupMock != nil {
|
||||
tt.setupMock(mockStorage, mockMetadata)
|
||||
}
|
||||
|
||||
manager, err := New(mockStorage, mockMetadata, nil, Config{})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
err = manager.Delete(ctx, tt.registry, tt.packageName, tt.version)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
mockStorage.AssertExpectations(t)
|
||||
mockMetadata.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHealth tests health check functionality
|
||||
func TestHealth(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupMock func(*MockStorageBackend, *MockMetadataStore)
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
// GOOD: Both healthy
|
||||
{
|
||||
name: "both storage and metadata healthy",
|
||||
setupMock: func(s *MockStorageBackend, m *MockMetadataStore) {
|
||||
s.On("Health", mock.Anything).Return(nil)
|
||||
m.On("Health", mock.Anything).Return(nil)
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
// WRONG: Storage unhealthy
|
||||
{
|
||||
name: "storage unhealthy",
|
||||
setupMock: func(s *MockStorageBackend, m *MockMetadataStore) {
|
||||
s.On("Health", mock.Anything).Return(errors.New("storage error"))
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "storage health check failed",
|
||||
},
|
||||
// WRONG: Metadata unhealthy
|
||||
{
|
||||
name: "metadata unhealthy",
|
||||
setupMock: func(s *MockStorageBackend, m *MockMetadataStore) {
|
||||
s.On("Health", mock.Anything).Return(nil)
|
||||
m.On("Health", mock.Anything).Return(errors.New("metadata error"))
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "metadata health check failed",
|
||||
},
|
||||
// BAD: Both unhealthy
|
||||
{
|
||||
name: "both unhealthy",
|
||||
setupMock: func(s *MockStorageBackend, m *MockMetadataStore) {
|
||||
s.On("Health", mock.Anything).Return(errors.New("storage error"))
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "storage health check failed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockStorage := &MockStorageBackend{}
|
||||
mockMetadata := &MockMetadataStore{}
|
||||
|
||||
if tt.setupMock != nil {
|
||||
tt.setupMock(mockStorage, mockMetadata)
|
||||
}
|
||||
|
||||
manager, err := New(mockStorage, mockMetadata, nil, Config{})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
err = manager.Health(ctx)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
mockStorage.AssertExpectations(t)
|
||||
mockMetadata.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetStats tests statistics retrieval
|
||||
func TestGetStats(t *testing.T) {
|
||||
mockStorage := &MockStorageBackend{}
|
||||
mockMetadata := &MockMetadataStore{}
|
||||
|
||||
expectedStats := &metadata.Stats{
|
||||
Registry: "npm",
|
||||
TotalPackages: 100,
|
||||
TotalSize: 1024 * 1024 * 100,
|
||||
TotalDownloads: 5000,
|
||||
}
|
||||
|
||||
mockMetadata.On("GetStats", mock.Anything, "npm").Return(expectedStats, nil)
|
||||
|
||||
manager, err := New(mockStorage, mockMetadata, nil, Config{})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
stats, err := manager.GetStats(ctx, "npm")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedStats, stats)
|
||||
mockMetadata.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestClose tests manager cleanup
|
||||
func TestClose(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupMock func(*MockStorageBackend, *MockMetadataStore)
|
||||
wantErr bool
|
||||
}{
|
||||
// GOOD: Clean close
|
||||
{
|
||||
name: "both close successfully",
|
||||
setupMock: func(s *MockStorageBackend, m *MockMetadataStore) {
|
||||
s.On("Close").Return(nil)
|
||||
m.On("Close").Return(nil)
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
// WRONG: Storage close fails
|
||||
{
|
||||
name: "storage close fails",
|
||||
setupMock: func(s *MockStorageBackend, m *MockMetadataStore) {
|
||||
s.On("Close").Return(errors.New("storage error"))
|
||||
m.On("Close").Return(nil)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
// WRONG: Metadata close fails
|
||||
{
|
||||
name: "metadata close fails",
|
||||
setupMock: func(s *MockStorageBackend, m *MockMetadataStore) {
|
||||
s.On("Close").Return(nil)
|
||||
m.On("Close").Return(errors.New("metadata error"))
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
// BAD: Both close fail
|
||||
{
|
||||
name: "both close fail",
|
||||
setupMock: func(s *MockStorageBackend, m *MockMetadataStore) {
|
||||
s.On("Close").Return(errors.New("storage error"))
|
||||
m.On("Close").Return(errors.New("metadata error"))
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockStorage := &MockStorageBackend{}
|
||||
mockMetadata := &MockMetadataStore{}
|
||||
|
||||
if tt.setupMock != nil {
|
||||
tt.setupMock(mockStorage, mockMetadata)
|
||||
}
|
||||
|
||||
manager, err := New(mockStorage, mockMetadata, nil, Config{})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = manager.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
mockStorage.AssertExpectations(t)
|
||||
mockMetadata.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEvict tests LRU eviction
|
||||
func TestEvict(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
needed int64
|
||||
setupMock func(*MockStorageBackend, *MockMetadataStore)
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
// GOOD: Successful eviction
|
||||
{
|
||||
name: "evict enough to free space",
|
||||
needed: 200,
|
||||
setupMock: func(s *MockStorageBackend, m *MockMetadataStore) {
|
||||
packages := []*metadata.Package{
|
||||
{
|
||||
ID: "1",
|
||||
Name: "old-pkg-1",
|
||||
Version: "1.0.0",
|
||||
Registry: "npm",
|
||||
StorageKey: "npm/old-pkg-1/1.0.0",
|
||||
Size: 100,
|
||||
},
|
||||
{
|
||||
ID: "2",
|
||||
Name: "old-pkg-2",
|
||||
Version: "1.0.0",
|
||||
Registry: "npm",
|
||||
StorageKey: "npm/old-pkg-2/1.0.0",
|
||||
Size: 150,
|
||||
},
|
||||
}
|
||||
m.On("ListPackages", mock.Anything, mock.MatchedBy(func(opts *metadata.ListOptions) bool {
|
||||
return opts.SortBy == "last_accessed" && !opts.SortDesc
|
||||
})).Return(packages, nil).Once()
|
||||
|
||||
s.On("Delete", mock.Anything, "npm/old-pkg-1/1.0.0").Return(nil)
|
||||
m.On("DeletePackage", mock.Anything, "npm", "old-pkg-1", "1.0.0").Return(nil)
|
||||
s.On("Delete", mock.Anything, "npm/old-pkg-2/1.0.0").Return(nil)
|
||||
m.On("DeletePackage", mock.Anything, "npm", "old-pkg-2", "1.0.0").Return(nil)
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
// EDGE: No packages to evict
|
||||
{
|
||||
name: "no packages available to evict",
|
||||
needed: 100,
|
||||
setupMock: func(s *MockStorageBackend, m *MockMetadataStore) {
|
||||
m.On("ListPackages", mock.Anything, mock.Anything).Return([]*metadata.Package{}, nil)
|
||||
},
|
||||
wantErr: false, // Doesn't error, just can't free enough
|
||||
},
|
||||
// EDGE: Eviction list error
|
||||
{
|
||||
name: "list packages fails",
|
||||
needed: 100,
|
||||
setupMock: func(s *MockStorageBackend, m *MockMetadataStore) {
|
||||
m.On("ListPackages", mock.Anything, mock.Anything).Return(nil, errors.New("list error"))
|
||||
},
|
||||
wantErr: false, // Doesn't error, just can't complete
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockStorage := &MockStorageBackend{}
|
||||
mockMetadata := &MockMetadataStore{}
|
||||
|
||||
if tt.setupMock != nil {
|
||||
tt.setupMock(mockStorage, mockMetadata)
|
||||
}
|
||||
|
||||
manager, err := New(mockStorage, mockMetadata, nil, Config{})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
err = manager.evict(ctx, tt.needed)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
mockStorage.AssertExpectations(t)
|
||||
mockMetadata.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateStorageKey tests storage key generation
|
||||
func TestGenerateStorageKey(t *testing.T) {
|
||||
mockStorage := &MockStorageBackend{}
|
||||
mockMetadata := &MockMetadataStore{}
|
||||
|
||||
manager, err := New(mockStorage, mockMetadata, nil, Config{})
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
registry string
|
||||
name string
|
||||
version string
|
||||
expected string
|
||||
}{
|
||||
{"npm", "react", "18.2.0", "npm/react/18.2.0"},
|
||||
{"pypi", "requests", "2.28.0", "pypi/requests/2.28.0"},
|
||||
{"go", "github.com/gin-gonic/gin", "v1.9.0", "go/github.com/gin-gonic/gin/v1.9.0"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.expected, func(t *testing.T) {
|
||||
key := manager.generateStorageKey(tt.registry, tt.name, tt.version)
|
||||
assert.Equal(t, tt.expected, key)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentGet tests concurrent access doesn't cause data races
|
||||
func TestConcurrentGet(t *testing.T) {
|
||||
mockStorage := &MockStorageBackend{}
|
||||
mockMetadata := &MockMetadataStore{}
|
||||
|
||||
// Setup mocks for concurrent access
|
||||
now := time.Now()
|
||||
expiresAt := now.Add(24 * time.Hour)
|
||||
pkg := &metadata.Package{
|
||||
ID: "test-id",
|
||||
Registry: "npm",
|
||||
Name: "concurrent",
|
||||
Version: "1.0.0",
|
||||
StorageKey: "npm/concurrent/1.0.0",
|
||||
CachedAt: now,
|
||||
LastAccessed: now,
|
||||
ExpiresAt: &expiresAt,
|
||||
}
|
||||
|
||||
// Use Maybe() to allow variable number of calls due to singleflight deduplication
|
||||
mockMetadata.On("GetPackage", mock.Anything, "npm", "concurrent", "1.0.0").Return(pkg, nil).Maybe()
|
||||
mockStorage.On("Get", mock.Anything, "npm/concurrent/1.0.0").Return(
|
||||
io.NopCloser(bytes.NewReader([]byte("test data"))), nil).Maybe()
|
||||
mockMetadata.On("UpdateDownloadCount", mock.Anything, "npm", "concurrent", "1.0.0").Return(nil).Maybe()
|
||||
|
||||
manager, err := New(mockStorage, mockMetadata, nil, Config{})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
const numGoroutines = 10
|
||||
|
||||
// Run concurrent gets
|
||||
errs := make(chan error, numGoroutines)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func() {
|
||||
_, err := manager.Get(ctx, "npm", "concurrent", "1.0.0", nil)
|
||||
errs <- err
|
||||
}()
|
||||
}
|
||||
|
||||
// Collect results
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
err := <-errs
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify at least one call was made (singleflight may deduplicate others)
|
||||
mockMetadata.AssertCalled(t, "GetPackage", mock.Anything, "npm", "concurrent", "1.0.0")
|
||||
}
|
||||
+360
@@ -0,0 +1,360 @@
|
||||
package cdn
|
||||
|
||||
import (
|
||||
"crypto/md5" // #nosec G501 -- MD5 used for ETag generation, not cryptographic security
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// CacheControl represents cache control directives
|
||||
type CacheControl struct {
|
||||
MaxAge int // max-age in seconds
|
||||
SMaxAge int // s-maxage in seconds (for shared caches)
|
||||
Public bool // public directive
|
||||
Private bool // private directive
|
||||
NoCache bool // no-cache directive
|
||||
NoStore bool // no-store directive
|
||||
MustRevalidate bool // must-revalidate directive
|
||||
ProxyRevalidate bool // proxy-revalidate directive
|
||||
Immutable bool // immutable directive
|
||||
StaleWhileRevalidate int // stale-while-revalidate in seconds
|
||||
}
|
||||
|
||||
// String returns the Cache-Control header value
|
||||
func (cc CacheControl) String() string {
|
||||
var parts []string
|
||||
|
||||
if cc.Public {
|
||||
parts = append(parts, "public")
|
||||
}
|
||||
if cc.Private {
|
||||
parts = append(parts, "private")
|
||||
}
|
||||
if cc.NoCache {
|
||||
parts = append(parts, "no-cache")
|
||||
}
|
||||
if cc.NoStore {
|
||||
parts = append(parts, "no-store")
|
||||
}
|
||||
if cc.MustRevalidate {
|
||||
parts = append(parts, "must-revalidate")
|
||||
}
|
||||
if cc.ProxyRevalidate {
|
||||
parts = append(parts, "proxy-revalidate")
|
||||
}
|
||||
if cc.Immutable {
|
||||
parts = append(parts, "immutable")
|
||||
}
|
||||
if cc.MaxAge > 0 {
|
||||
parts = append(parts, fmt.Sprintf("max-age=%d", cc.MaxAge))
|
||||
}
|
||||
if cc.SMaxAge > 0 {
|
||||
parts = append(parts, fmt.Sprintf("s-maxage=%d", cc.SMaxAge))
|
||||
}
|
||||
if cc.StaleWhileRevalidate > 0 {
|
||||
parts = append(parts, fmt.Sprintf("stale-while-revalidate=%d", cc.StaleWhileRevalidate))
|
||||
}
|
||||
|
||||
result := ""
|
||||
for i, part := range parts {
|
||||
if i > 0 {
|
||||
result += ", "
|
||||
}
|
||||
result += part
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Middleware provides CDN and HTTP caching functionality
|
||||
type Middleware struct {
|
||||
defaultCacheControl CacheControl
|
||||
enableETag bool
|
||||
enableVary bool
|
||||
}
|
||||
|
||||
// Config holds CDN middleware configuration
|
||||
type Config struct {
|
||||
DefaultCacheControl CacheControl
|
||||
EnableETag bool
|
||||
EnableVary bool
|
||||
}
|
||||
|
||||
// NewMiddleware creates a new CDN middleware
|
||||
func NewMiddleware(cfg Config) *Middleware {
|
||||
return &Middleware{
|
||||
defaultCacheControl: cfg.DefaultCacheControl,
|
||||
enableETag: cfg.EnableETag,
|
||||
enableVary: cfg.EnableVary,
|
||||
}
|
||||
}
|
||||
|
||||
// Handler wraps an HTTP handler with CDN caching support
|
||||
func (m *Middleware) Handler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Wrap response writer to capture response for ETag generation
|
||||
rw := &responseWriter{
|
||||
ResponseWriter: w,
|
||||
statusCode: http.StatusOK,
|
||||
body: nil,
|
||||
}
|
||||
|
||||
// Call next handler
|
||||
next.ServeHTTP(rw, r)
|
||||
|
||||
// Apply caching headers if successful response
|
||||
if rw.statusCode >= 200 && rw.statusCode < 300 {
|
||||
m.applyCachingHeaders(rw, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// applyCachingHeaders applies appropriate caching headers to the response
|
||||
func (m *Middleware) applyCachingHeaders(w *responseWriter, r *http.Request) {
|
||||
// Set Cache-Control header if not already set
|
||||
if w.Header().Get("Cache-Control") == "" {
|
||||
w.Header().Set("Cache-Control", m.defaultCacheControl.String())
|
||||
}
|
||||
|
||||
// Set Vary header for content negotiation
|
||||
if m.enableVary {
|
||||
m.setVaryHeader(w, r)
|
||||
}
|
||||
|
||||
// Generate and check ETag if enabled
|
||||
if m.enableETag && w.body != nil {
|
||||
m.handleETag(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// setVaryHeader sets the Vary header based on request
|
||||
func (m *Middleware) setVaryHeader(w *responseWriter, r *http.Request) {
|
||||
varies := []string{}
|
||||
|
||||
// Vary on Accept-Encoding for compression
|
||||
if r.Header.Get("Accept-Encoding") != "" {
|
||||
varies = append(varies, "Accept-Encoding")
|
||||
}
|
||||
|
||||
// Vary on Authorization for authenticated requests
|
||||
if r.Header.Get("Authorization") != "" {
|
||||
varies = append(varies, "Authorization")
|
||||
}
|
||||
|
||||
// Vary on Accept for content negotiation
|
||||
if r.Header.Get("Accept") != "" {
|
||||
varies = append(varies, "Accept")
|
||||
}
|
||||
|
||||
if len(varies) > 0 {
|
||||
varyHeader := ""
|
||||
for i, v := range varies {
|
||||
if i > 0 {
|
||||
varyHeader += ", "
|
||||
}
|
||||
varyHeader += v
|
||||
}
|
||||
w.Header().Set("Vary", varyHeader)
|
||||
}
|
||||
}
|
||||
|
||||
// handleETag generates ETag and handles conditional requests
|
||||
func (m *Middleware) handleETag(w *responseWriter, r *http.Request) {
|
||||
// Generate ETag from response body
|
||||
etag := m.generateETag(w.body)
|
||||
w.Header().Set("ETag", etag)
|
||||
|
||||
// Handle conditional requests
|
||||
if ifNoneMatch := r.Header.Get("If-None-Match"); ifNoneMatch != "" {
|
||||
if ifNoneMatch == etag {
|
||||
// ETag matches - return 304 Not Modified
|
||||
w.WriteHeader(http.StatusNotModified)
|
||||
w.body = nil // Clear body for 304 response
|
||||
log.Debug().
|
||||
Str("path", r.URL.Path).
|
||||
Str("etag", etag).
|
||||
Msg("ETag match - returning 304 Not Modified")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Handle If-Modified-Since
|
||||
if lastModified := w.Header().Get("Last-Modified"); lastModified != "" {
|
||||
if ifModifiedSince := r.Header.Get("If-Modified-Since"); ifModifiedSince != "" {
|
||||
lastModTime, err := http.ParseTime(lastModified)
|
||||
if err == nil {
|
||||
ifModTime, err := http.ParseTime(ifModifiedSince)
|
||||
if err == nil && !lastModTime.After(ifModTime) {
|
||||
// Not modified - return 304
|
||||
w.WriteHeader(http.StatusNotModified)
|
||||
w.body = nil
|
||||
log.Debug().
|
||||
Str("path", r.URL.Path).
|
||||
Time("last_modified", lastModTime).
|
||||
Msg("Not modified - returning 304")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// generateETag creates an ETag for HTTP caching
|
||||
// NOTE: MD5 is used for content fingerprinting (ETag), not cryptographic security
|
||||
func (m *Middleware) generateETag(body []byte) string {
|
||||
if body == nil {
|
||||
return ""
|
||||
}
|
||||
hash := md5.Sum(body) // #nosec G401 -- MD5 used for ETag, not cryptographic security
|
||||
return `"` + hex.EncodeToString(hash[:]) + `"`
|
||||
}
|
||||
|
||||
// SetLastModified sets the Last-Modified header
|
||||
func SetLastModified(w http.ResponseWriter, t time.Time) {
|
||||
w.Header().Set("Last-Modified", t.UTC().Format(http.TimeFormat))
|
||||
}
|
||||
|
||||
// SetCacheControl sets a custom Cache-Control header
|
||||
func SetCacheControl(w http.ResponseWriter, cc CacheControl) {
|
||||
w.Header().Set("Cache-Control", cc.String())
|
||||
}
|
||||
|
||||
// SetNoCache sets headers to prevent caching
|
||||
func SetNoCache(w http.ResponseWriter) {
|
||||
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
|
||||
w.Header().Set("Pragma", "no-cache")
|
||||
w.Header().Set("Expires", "0")
|
||||
}
|
||||
|
||||
// SetImmutable sets headers for immutable content (content-addressed files)
|
||||
func SetImmutable(w http.ResponseWriter, maxAge int) {
|
||||
cc := CacheControl{
|
||||
Public: true,
|
||||
MaxAge: maxAge,
|
||||
Immutable: true,
|
||||
}
|
||||
w.Header().Set("Cache-Control", cc.String())
|
||||
}
|
||||
|
||||
// responseWriter wraps http.ResponseWriter to capture response
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
body []byte
|
||||
}
|
||||
|
||||
func (rw *responseWriter) WriteHeader(statusCode int) {
|
||||
rw.statusCode = statusCode
|
||||
rw.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (rw *responseWriter) Write(b []byte) (int, error) {
|
||||
// Capture body for ETag generation
|
||||
if rw.body == nil {
|
||||
rw.body = make([]byte, 0, len(b))
|
||||
}
|
||||
rw.body = append(rw.body, b...)
|
||||
return rw.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
// HandleRange handles HTTP Range requests for partial content
|
||||
func HandleRange(w http.ResponseWriter, r *http.Request, content io.ReadSeeker, size int64, modTime time.Time) error {
|
||||
// Set Last-Modified header
|
||||
SetLastModified(w, modTime)
|
||||
|
||||
// Check for Range header
|
||||
rangeHeader := r.Header.Get("Range")
|
||||
if rangeHeader == "" {
|
||||
// No range request - serve full content
|
||||
w.Header().Set("Content-Length", strconv.FormatInt(size, 10))
|
||||
w.Header().Set("Accept-Ranges", "bytes")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, err := io.Copy(w, content)
|
||||
return err
|
||||
}
|
||||
|
||||
// Parse range header (simplified - only handles single range)
|
||||
// Format: bytes=start-end
|
||||
var start, end int64
|
||||
n, err := fmt.Sscanf(rangeHeader, "bytes=%d-%d", &start, &end)
|
||||
if err != nil || n != 2 {
|
||||
// Invalid range - serve full content
|
||||
w.Header().Set("Content-Length", strconv.FormatInt(size, 10))
|
||||
w.Header().Set("Accept-Ranges", "bytes")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, err := io.Copy(w, content)
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate range
|
||||
if start < 0 || start >= size || end < start || end >= size {
|
||||
w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", size))
|
||||
w.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Seek to start position
|
||||
if _, err := content.Seek(start, io.SeekStart); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Calculate content length
|
||||
contentLength := end - start + 1
|
||||
|
||||
// Set headers for partial content
|
||||
w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, size))
|
||||
w.Header().Set("Content-Length", strconv.FormatInt(contentLength, 10))
|
||||
w.Header().Set("Accept-Ranges", "bytes")
|
||||
w.WriteHeader(http.StatusPartialContent)
|
||||
|
||||
// Copy range to response
|
||||
_, err = io.CopyN(w, content, contentLength)
|
||||
return err
|
||||
}
|
||||
|
||||
// DefaultCacheControl returns sensible defaults for different content types
|
||||
func DefaultCacheControl(contentType string, versioned bool) CacheControl {
|
||||
if versioned {
|
||||
// Content-addressed or versioned resources can be cached forever
|
||||
return CacheControl{
|
||||
Public: true,
|
||||
MaxAge: 31536000, // 1 year
|
||||
Immutable: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Default caching based on content type
|
||||
switch contentType {
|
||||
case "application/json":
|
||||
return CacheControl{
|
||||
Public: true,
|
||||
MaxAge: 3600, // 1 hour
|
||||
SMaxAge: 7200, // 2 hours for shared caches
|
||||
}
|
||||
case "application/octet-stream", "application/x-gzip", "application/zip":
|
||||
// Binary packages
|
||||
return CacheControl{
|
||||
Public: true,
|
||||
MaxAge: 86400, // 1 day
|
||||
SMaxAge: 604800, // 1 week for shared caches
|
||||
}
|
||||
case "text/html":
|
||||
// HTML should revalidate
|
||||
return CacheControl{
|
||||
Public: true,
|
||||
MaxAge: 0,
|
||||
MustRevalidate: true,
|
||||
}
|
||||
default:
|
||||
return CacheControl{
|
||||
Public: true,
|
||||
MaxAge: 3600, // 1 hour default
|
||||
SMaxAge: 7200,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,453 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config is the main configuration struct
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server" json:"server"`
|
||||
Storage StorageConfig `mapstructure:"storage" json:"storage"`
|
||||
Metadata MetadataConfig `mapstructure:"metadata" json:"metadata"`
|
||||
Cache CacheConfig `mapstructure:"cache" json:"cache"`
|
||||
Security SecurityConfig `mapstructure:"security" json:"security"`
|
||||
Auth AuthConfig `mapstructure:"auth" json:"auth"`
|
||||
Network NetworkConfig `mapstructure:"network" json:"network"`
|
||||
Logging LoggingConfig `mapstructure:"logging" json:"logging"`
|
||||
Handlers HandlersConfig `mapstructure:"handlers" json:"handlers"`
|
||||
}
|
||||
|
||||
// ServerConfig contains HTTP server configuration
|
||||
type ServerConfig struct {
|
||||
Host string `mapstructure:"host" json:"host"`
|
||||
Port int `mapstructure:"port" json:"port"`
|
||||
ReadTimeout time.Duration `mapstructure:"read_timeout" json:"read_timeout"`
|
||||
WriteTimeout time.Duration `mapstructure:"write_timeout" json:"write_timeout"`
|
||||
IdleTimeout time.Duration `mapstructure:"idle_timeout" json:"idle_timeout"`
|
||||
TLS TLSConfig `mapstructure:"tls" json:"tls"`
|
||||
}
|
||||
|
||||
// TLSConfig contains TLS/HTTPS configuration
|
||||
type TLSConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
CertFile string `mapstructure:"cert_file" json:"cert_file"`
|
||||
KeyFile string `mapstructure:"key_file" json:"key_file"`
|
||||
}
|
||||
|
||||
// StorageConfig contains storage backend configuration
|
||||
type StorageConfig struct {
|
||||
Backend string `mapstructure:"backend" json:"backend"` // filesystem, s3, smb, nfs
|
||||
Path string `mapstructure:"path" json:"path"`
|
||||
Filesystem FilesystemConfig `mapstructure:"filesystem" json:"filesystem"`
|
||||
S3 S3Config `mapstructure:"s3" json:"s3"`
|
||||
SMB SMBConfig `mapstructure:"smb" json:"smb"`
|
||||
Options map[string]interface{} `mapstructure:"options" json:"options"`
|
||||
}
|
||||
|
||||
// FilesystemConfig contains local filesystem storage configuration
|
||||
type FilesystemConfig struct {
|
||||
BasePath string `mapstructure:"base_path" json:"base_path"`
|
||||
}
|
||||
|
||||
// S3Config contains S3-compatible storage configuration
|
||||
type S3Config struct {
|
||||
Endpoint string `mapstructure:"endpoint" json:"endpoint"`
|
||||
Region string `mapstructure:"region" json:"region"`
|
||||
Bucket string `mapstructure:"bucket" json:"bucket"`
|
||||
AccessKeyID string `mapstructure:"access_key_id" json:"access_key_id"`
|
||||
SecretAccessKey string `mapstructure:"secret_access_key" json:"-"` // Don't serialize secrets
|
||||
UseSSL bool `mapstructure:"use_ssl" json:"use_ssl"`
|
||||
}
|
||||
|
||||
// SMBConfig contains SMB/CIFS storage configuration
|
||||
type SMBConfig struct {
|
||||
Host string `mapstructure:"host" json:"host"`
|
||||
Share string `mapstructure:"share" json:"share"`
|
||||
Username string `mapstructure:"username" json:"username"`
|
||||
Password string `mapstructure:"password" json:"-"` // Don't serialize secrets
|
||||
Domain string `mapstructure:"domain" json:"domain"`
|
||||
}
|
||||
|
||||
// MetadataConfig contains metadata store configuration
|
||||
type MetadataConfig struct {
|
||||
Backend string `mapstructure:"backend" json:"backend"` // sqlite, postgresql, file
|
||||
Connection string `mapstructure:"connection" json:"connection"`
|
||||
SQLite SQLiteConfig `mapstructure:"sqlite" json:"sqlite"`
|
||||
PostgreSQL PostgreSQLConfig `mapstructure:"postgresql" json:"postgresql"`
|
||||
}
|
||||
|
||||
// SQLiteConfig contains SQLite-specific configuration
|
||||
type SQLiteConfig struct {
|
||||
Path string `mapstructure:"path" json:"path"`
|
||||
WALMode bool `mapstructure:"wal_mode" json:"wal_mode"`
|
||||
}
|
||||
|
||||
// PostgreSQLConfig contains PostgreSQL-specific configuration
|
||||
type PostgreSQLConfig struct {
|
||||
Host string `mapstructure:"host" json:"host"`
|
||||
Port int `mapstructure:"port" json:"port"`
|
||||
Database string `mapstructure:"database" json:"database"`
|
||||
User string `mapstructure:"user" json:"user"`
|
||||
Password string `mapstructure:"password" json:"-"` // Don't serialize secrets
|
||||
SSLMode string `mapstructure:"ssl_mode" json:"ssl_mode"`
|
||||
}
|
||||
|
||||
// CacheConfig contains cache management configuration
|
||||
type CacheConfig struct {
|
||||
DefaultTTL time.Duration `mapstructure:"default_ttl" json:"default_ttl"`
|
||||
CleanupInterval time.Duration `mapstructure:"cleanup_interval" json:"cleanup_interval"`
|
||||
MaxSizeBytes int64 `mapstructure:"max_size_bytes" json:"max_size_bytes"`
|
||||
PerProjectQuota int64 `mapstructure:"per_project_quota" json:"per_project_quota"`
|
||||
TTLOverrides map[string]time.Duration `mapstructure:"ttl_overrides" json:"ttl_overrides"` // Per ecosystem
|
||||
}
|
||||
|
||||
// SecurityConfig contains security scanning configuration
|
||||
type SecurityConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
ScanOnDownload bool `mapstructure:"scan_on_download" json:"scan_on_download"` // Scan packages on first download
|
||||
RescanInterval time.Duration `mapstructure:"rescan_interval" json:"rescan_interval"` // How often to re-scan (e.g., 24h, 168h for weekly)
|
||||
BlockOnSeverity string `mapstructure:"block_on_severity" json:"block_on_severity"` // none, low, medium, high, critical
|
||||
BlockThresholds VulnerabilityThresholds `mapstructure:"block_thresholds" json:"block_thresholds"` // Max vulns per severity before blocking
|
||||
UpdateDBOnStartup bool `mapstructure:"update_db_on_startup" json:"update_db_on_startup"` // Update vulnerability databases on startup
|
||||
AllowedPackages []string `mapstructure:"allowed_packages" json:"allowed_packages"` // Packages that bypass security checks (format: "registry/name@version" or "registry/name")
|
||||
IgnoredCVEs []string `mapstructure:"ignored_cves" json:"ignored_cves"` // CVE IDs to ignore globally (e.g., "CVE-2021-23337")
|
||||
Scanners ScannersConfig `mapstructure:"scanners" json:"scanners"`
|
||||
}
|
||||
|
||||
// VulnerabilityThresholds defines max allowed vulnerabilities per severity
|
||||
type VulnerabilityThresholds struct {
|
||||
Critical int `mapstructure:"critical" json:"critical"` // Max critical vulns (0 = block any)
|
||||
High int `mapstructure:"high" json:"high"` // Max high vulns
|
||||
Medium int `mapstructure:"medium" json:"medium"` // Max medium vulns
|
||||
Low int `mapstructure:"low" json:"low"` // Max low vulns (-1 = unlimited)
|
||||
}
|
||||
|
||||
// ScannersConfig contains individual scanner configurations
|
||||
type ScannersConfig struct {
|
||||
Trivy TrivyConfig `mapstructure:"trivy" json:"trivy"`
|
||||
OSV OSVConfig `mapstructure:"osv" json:"osv"`
|
||||
Static StaticConfig `mapstructure:"static" json:"static"`
|
||||
Grype GrypeConfig `mapstructure:"grype" json:"grype"`
|
||||
Govulncheck GovulncheckConfig `mapstructure:"govulncheck" json:"govulncheck"`
|
||||
NpmAudit NpmAuditConfig `mapstructure:"npm_audit" json:"npm_audit"`
|
||||
PipAudit PipAuditConfig `mapstructure:"pip_audit" json:"pip_audit"`
|
||||
GHSA GHSAConfig `mapstructure:"ghsa" json:"ghsa"`
|
||||
}
|
||||
|
||||
// TrivyConfig contains Trivy scanner configuration
|
||||
type TrivyConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
Timeout time.Duration `mapstructure:"timeout" json:"timeout"`
|
||||
CacheDB string `mapstructure:"cache_db" json:"cache_db"`
|
||||
}
|
||||
|
||||
// OSVConfig contains OSV scanner configuration
|
||||
type OSVConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
APIURL string `mapstructure:"api_url" json:"api_url"`
|
||||
Timeout time.Duration `mapstructure:"timeout" json:"timeout"`
|
||||
}
|
||||
|
||||
// StaticConfig contains static analysis configuration
|
||||
type StaticConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
MaxPackageSize int64 `mapstructure:"max_package_size" json:"max_package_size"`
|
||||
CheckChecksums bool `mapstructure:"check_checksums" json:"check_checksums"`
|
||||
BlockSuspicious bool `mapstructure:"block_suspicious" json:"block_suspicious"`
|
||||
AllowedLicenses []string `mapstructure:"allowed_licenses" json:"allowed_licenses"`
|
||||
}
|
||||
|
||||
// GrypeConfig contains Grype scanner configuration
|
||||
type GrypeConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
Timeout time.Duration `mapstructure:"timeout" json:"timeout"`
|
||||
}
|
||||
|
||||
// GovulncheckConfig contains govulncheck scanner configuration
|
||||
type GovulncheckConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
Timeout time.Duration `mapstructure:"timeout" json:"timeout"`
|
||||
}
|
||||
|
||||
// NpmAuditConfig contains npm audit scanner configuration
|
||||
type NpmAuditConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
Timeout time.Duration `mapstructure:"timeout" json:"timeout"`
|
||||
}
|
||||
|
||||
// PipAuditConfig contains pip-audit scanner configuration
|
||||
type PipAuditConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
Timeout time.Duration `mapstructure:"timeout" json:"timeout"`
|
||||
}
|
||||
|
||||
// GHSAConfig contains GitHub Advisory Database scanner configuration
|
||||
type GHSAConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
Timeout time.Duration `mapstructure:"timeout" json:"timeout"`
|
||||
Token string `mapstructure:"token" json:"-"` // GitHub token for higher rate limits (don't serialize)
|
||||
}
|
||||
|
||||
// AuthConfig contains authentication configuration
|
||||
type AuthConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
KeyExpiration time.Duration `mapstructure:"key_expiration" json:"key_expiration"`
|
||||
BcryptCost int `mapstructure:"bcrypt_cost" json:"bcrypt_cost"`
|
||||
AuditLog bool `mapstructure:"audit_log" json:"audit_log"`
|
||||
}
|
||||
|
||||
// NetworkConfig contains network resilience configuration
|
||||
type NetworkConfig struct {
|
||||
ConnectTimeout time.Duration `mapstructure:"connect_timeout" json:"connect_timeout"`
|
||||
ReadTimeout time.Duration `mapstructure:"read_timeout" json:"read_timeout"`
|
||||
WriteTimeout time.Duration `mapstructure:"write_timeout" json:"write_timeout"`
|
||||
MaxIdleConns int `mapstructure:"max_idle_conns" json:"max_idle_conns"`
|
||||
MaxConnsPerHost int `mapstructure:"max_conns_per_host" json:"max_conns_per_host"`
|
||||
RateLimit RateLimitConfig `mapstructure:"rate_limit" json:"rate_limit"`
|
||||
CircuitBreaker CircuitBreakerConfig `mapstructure:"circuit_breaker" json:"circuit_breaker"`
|
||||
Retry RetryConfig `mapstructure:"retry" json:"retry"`
|
||||
}
|
||||
|
||||
// RateLimitConfig contains rate limiting configuration
|
||||
type RateLimitConfig struct {
|
||||
PerAPIKey int `mapstructure:"per_api_key" json:"per_api_key"`
|
||||
PerIP int `mapstructure:"per_ip" json:"per_ip"`
|
||||
BurstSize int `mapstructure:"burst_size" json:"burst_size"`
|
||||
}
|
||||
|
||||
// CircuitBreakerConfig contains circuit breaker configuration
|
||||
type CircuitBreakerConfig struct {
|
||||
Threshold int `mapstructure:"threshold" json:"threshold"`
|
||||
Timeout time.Duration `mapstructure:"timeout" json:"timeout"`
|
||||
ResetInterval time.Duration `mapstructure:"reset_interval" json:"reset_interval"`
|
||||
}
|
||||
|
||||
// RetryConfig contains retry policy configuration
|
||||
type RetryConfig struct {
|
||||
MaxAttempts int `mapstructure:"max_attempts" json:"max_attempts"`
|
||||
InitialBackoff time.Duration `mapstructure:"initial_backoff" json:"initial_backoff"`
|
||||
MaxBackoff time.Duration `mapstructure:"max_backoff" json:"max_backoff"`
|
||||
}
|
||||
|
||||
// LoggingConfig contains logging configuration
|
||||
type LoggingConfig struct {
|
||||
Level string `mapstructure:"level" json:"level"` // debug, info, warn, error
|
||||
Format string `mapstructure:"format" json:"format"` // json, pretty
|
||||
}
|
||||
|
||||
// HandlersConfig contains package manager handler configurations
|
||||
type HandlersConfig struct {
|
||||
Go GoHandlerConfig `mapstructure:"go" json:"go"`
|
||||
NPM NPMHandlerConfig `mapstructure:"npm" json:"npm"`
|
||||
PyPI PyPIHandlerConfig `mapstructure:"pypi" json:"pypi"`
|
||||
}
|
||||
|
||||
// GoHandlerConfig contains Go proxy configuration
|
||||
type GoHandlerConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
UpstreamProxy string `mapstructure:"upstream_proxy" json:"upstream_proxy"`
|
||||
ChecksumDB string `mapstructure:"checksum_db" json:"checksum_db"`
|
||||
VerifyChecksums bool `mapstructure:"verify_checksums" json:"verify_checksums"`
|
||||
GitCredentialsFile string `mapstructure:"git_credentials_file" json:"git_credentials_file"` // Path to git credentials JSON file
|
||||
}
|
||||
|
||||
// NPMHandlerConfig contains NPM registry configuration
|
||||
type NPMHandlerConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
UpstreamRegistry string `mapstructure:"upstream_registry" json:"upstream_registry"`
|
||||
}
|
||||
|
||||
// PyPIHandlerConfig contains PyPI configuration
|
||||
type PyPIHandlerConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
UpstreamURL string `mapstructure:"upstream_url" json:"upstream_url"`
|
||||
SimpleAPIURL string `mapstructure:"simple_api_url" json:"simple_api_url"`
|
||||
}
|
||||
|
||||
// Default returns a configuration with sensible defaults
|
||||
func Default() *Config {
|
||||
return &Config{
|
||||
Server: ServerConfig{
|
||||
Host: "0.0.0.0",
|
||||
Port: 8080,
|
||||
ReadTimeout: 5 * time.Minute,
|
||||
WriteTimeout: 5 * time.Minute,
|
||||
IdleTimeout: 2 * time.Minute,
|
||||
TLS: TLSConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
Storage: StorageConfig{
|
||||
Backend: "filesystem",
|
||||
Path: "/var/cache/gohoarder",
|
||||
Filesystem: FilesystemConfig{
|
||||
BasePath: "/var/cache/gohoarder",
|
||||
},
|
||||
},
|
||||
Metadata: MetadataConfig{
|
||||
Backend: "sqlite",
|
||||
Connection: "file:gohoarder.db?cache=shared&mode=rwc",
|
||||
SQLite: SQLiteConfig{
|
||||
Path: "gohoarder.db",
|
||||
WALMode: true,
|
||||
},
|
||||
},
|
||||
Cache: CacheConfig{
|
||||
DefaultTTL: 7 * 24 * time.Hour,
|
||||
CleanupInterval: 1 * time.Hour,
|
||||
MaxSizeBytes: 500 * 1024 * 1024 * 1024, // 500GB
|
||||
PerProjectQuota: 50 * 1024 * 1024 * 1024, // 50GB
|
||||
TTLOverrides: map[string]time.Duration{
|
||||
"npm": 7 * 24 * time.Hour,
|
||||
"pip": 7 * 24 * time.Hour,
|
||||
"go": 7 * 24 * time.Hour,
|
||||
},
|
||||
},
|
||||
Security: SecurityConfig{
|
||||
Enabled: false,
|
||||
BlockOnSeverity: "high",
|
||||
Scanners: ScannersConfig{
|
||||
Trivy: TrivyConfig{
|
||||
Enabled: false,
|
||||
Timeout: 5 * time.Minute,
|
||||
CacheDB: "/var/lib/trivy",
|
||||
},
|
||||
OSV: OSVConfig{
|
||||
Enabled: false,
|
||||
APIURL: "https://api.osv.dev",
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
Static: StaticConfig{
|
||||
Enabled: true,
|
||||
MaxPackageSize: 2 * 1024 * 1024 * 1024, // 2GB
|
||||
CheckChecksums: true,
|
||||
BlockSuspicious: false,
|
||||
},
|
||||
Grype: GrypeConfig{
|
||||
Enabled: false,
|
||||
Timeout: 5 * time.Minute,
|
||||
},
|
||||
Govulncheck: GovulncheckConfig{
|
||||
Enabled: false,
|
||||
Timeout: 5 * time.Minute,
|
||||
},
|
||||
NpmAudit: NpmAuditConfig{
|
||||
Enabled: false,
|
||||
Timeout: 2 * time.Minute,
|
||||
},
|
||||
PipAudit: PipAuditConfig{
|
||||
Enabled: false,
|
||||
Timeout: 2 * time.Minute,
|
||||
},
|
||||
GHSA: GHSAConfig{
|
||||
Enabled: false,
|
||||
Timeout: 30 * time.Second,
|
||||
Token: "",
|
||||
},
|
||||
},
|
||||
},
|
||||
Auth: AuthConfig{
|
||||
Enabled: true,
|
||||
KeyExpiration: 0, // Never expire
|
||||
BcryptCost: 10,
|
||||
AuditLog: true,
|
||||
},
|
||||
Network: NetworkConfig{
|
||||
ConnectTimeout: 10 * time.Second,
|
||||
ReadTimeout: 5 * time.Minute,
|
||||
WriteTimeout: 5 * time.Minute,
|
||||
MaxIdleConns: 100,
|
||||
MaxConnsPerHost: 10,
|
||||
RateLimit: RateLimitConfig{
|
||||
PerAPIKey: 1000,
|
||||
PerIP: 100,
|
||||
BurstSize: 50,
|
||||
},
|
||||
CircuitBreaker: CircuitBreakerConfig{
|
||||
Threshold: 5,
|
||||
Timeout: 30 * time.Second,
|
||||
ResetInterval: 60 * time.Second,
|
||||
},
|
||||
Retry: RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialBackoff: 1 * time.Second,
|
||||
MaxBackoff: 30 * time.Second,
|
||||
},
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
},
|
||||
Handlers: HandlersConfig{
|
||||
Go: GoHandlerConfig{
|
||||
Enabled: true,
|
||||
UpstreamProxy: "https://proxy.golang.org",
|
||||
ChecksumDB: "https://sum.golang.org",
|
||||
VerifyChecksums: true,
|
||||
},
|
||||
NPM: NPMHandlerConfig{
|
||||
Enabled: true,
|
||||
UpstreamRegistry: "https://registry.npmjs.org",
|
||||
},
|
||||
PyPI: PyPIHandlerConfig{
|
||||
Enabled: true,
|
||||
UpstreamURL: "https://pypi.org",
|
||||
SimpleAPIURL: "https://pypi.org/simple",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Validate validates the configuration
|
||||
func (c *Config) Validate() error {
|
||||
// Validate server
|
||||
if c.Server.Port < 1 || c.Server.Port > 65535 {
|
||||
return fmt.Errorf("server.port must be between 1 and 65535, got %d", c.Server.Port)
|
||||
}
|
||||
|
||||
// Validate storage backend
|
||||
validStorageBackends := map[string]bool{"filesystem": true, "s3": true, "smb": true, "nfs": true}
|
||||
if !validStorageBackends[c.Storage.Backend] {
|
||||
return fmt.Errorf("storage.backend must be one of: filesystem, s3, smb, nfs; got %s", c.Storage.Backend)
|
||||
}
|
||||
|
||||
// Validate metadata backend
|
||||
validMetadataBackends := map[string]bool{"sqlite": true, "postgresql": true, "file": true}
|
||||
if !validMetadataBackends[c.Metadata.Backend] {
|
||||
return fmt.Errorf("metadata.backend must be one of: sqlite, postgresql, file; got %s", c.Metadata.Backend)
|
||||
}
|
||||
|
||||
// Validate cache
|
||||
if c.Cache.DefaultTTL < 0 {
|
||||
return fmt.Errorf("cache.default_ttl cannot be negative")
|
||||
}
|
||||
if c.Cache.MaxSizeBytes < 0 {
|
||||
return fmt.Errorf("cache.max_size_bytes cannot be negative")
|
||||
}
|
||||
|
||||
// Validate security
|
||||
validSeverities := map[string]bool{"none": true, "low": true, "medium": true, "high": true, "critical": true}
|
||||
if !validSeverities[c.Security.BlockOnSeverity] {
|
||||
return fmt.Errorf("security.block_on_severity must be one of: none, low, medium, high, critical; got %s", c.Security.BlockOnSeverity)
|
||||
}
|
||||
|
||||
// Validate logging level
|
||||
validLevels := map[string]bool{"debug": true, "info": true, "warn": true, "error": true}
|
||||
if !validLevels[c.Logging.Level] {
|
||||
return fmt.Errorf("logging.level must be one of: debug, info, warn, error; got %s", c.Logging.Level)
|
||||
}
|
||||
|
||||
// Validate logging format
|
||||
validFormats := map[string]bool{"json": true, "pretty": true}
|
||||
if !validFormats[c.Logging.Format] {
|
||||
return fmt.Errorf("logging.format must be one of: json, pretty; got %s", c.Logging.Format)
|
||||
}
|
||||
|
||||
// Validate auth
|
||||
if c.Auth.BcryptCost < 4 || c.Auth.BcryptCost > 31 {
|
||||
return fmt.Errorf("auth.bcrypt_cost must be between 4 and 31, got %d", c.Auth.BcryptCost)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,383 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ConfigTestSuite struct {
|
||||
suite.Suite
|
||||
tempDir string
|
||||
}
|
||||
|
||||
func TestConfigTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(ConfigTestSuite))
|
||||
}
|
||||
|
||||
func (s *ConfigTestSuite) SetupTest() {
|
||||
var err error
|
||||
s.tempDir, err = os.MkdirTemp("", "gohoarder-config-test-*")
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *ConfigTestSuite) TearDownTest() {
|
||||
_ = os.RemoveAll(s.tempDir) // #nosec G104 -- Cleanup
|
||||
}
|
||||
|
||||
func (s *ConfigTestSuite) TestDefault() {
|
||||
cfg := Default()
|
||||
s.NotNil(cfg)
|
||||
s.Equal("0.0.0.0", cfg.Server.Host)
|
||||
s.Equal(8080, cfg.Server.Port)
|
||||
s.Equal("filesystem", cfg.Storage.Backend)
|
||||
s.Equal("sqlite", cfg.Metadata.Backend)
|
||||
s.NoError(cfg.Validate())
|
||||
}
|
||||
|
||||
func (s *ConfigTestSuite) TestValidate() {
|
||||
tests := []struct {
|
||||
name string
|
||||
modify func(*Config)
|
||||
expectError bool
|
||||
errorSubstr string
|
||||
}{
|
||||
{
|
||||
name: "valid_config",
|
||||
modify: func(c *Config) {},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid_port_too_low",
|
||||
modify: func(c *Config) {
|
||||
c.Server.Port = 0
|
||||
},
|
||||
expectError: true,
|
||||
errorSubstr: "port must be between",
|
||||
},
|
||||
{
|
||||
name: "invalid_port_too_high",
|
||||
modify: func(c *Config) {
|
||||
c.Server.Port = 70000
|
||||
},
|
||||
expectError: true,
|
||||
errorSubstr: "port must be between",
|
||||
},
|
||||
{
|
||||
name: "invalid_storage_backend",
|
||||
modify: func(c *Config) {
|
||||
c.Storage.Backend = "invalid"
|
||||
},
|
||||
expectError: true,
|
||||
errorSubstr: "storage.backend must be one of",
|
||||
},
|
||||
{
|
||||
name: "invalid_metadata_backend",
|
||||
modify: func(c *Config) {
|
||||
c.Metadata.Backend = "mongodb"
|
||||
},
|
||||
expectError: true,
|
||||
errorSubstr: "metadata.backend must be one of",
|
||||
},
|
||||
{
|
||||
name: "negative_ttl",
|
||||
modify: func(c *Config) {
|
||||
c.Cache.DefaultTTL = -1 * time.Hour
|
||||
},
|
||||
expectError: true,
|
||||
errorSubstr: "cannot be negative",
|
||||
},
|
||||
{
|
||||
name: "negative_cache_size",
|
||||
modify: func(c *Config) {
|
||||
c.Cache.MaxSizeBytes = -100
|
||||
},
|
||||
expectError: true,
|
||||
errorSubstr: "cannot be negative",
|
||||
},
|
||||
{
|
||||
name: "invalid_severity",
|
||||
modify: func(c *Config) {
|
||||
c.Security.BlockOnSeverity = "super-high"
|
||||
},
|
||||
expectError: true,
|
||||
errorSubstr: "block_on_severity must be one of",
|
||||
},
|
||||
{
|
||||
name: "invalid_log_level",
|
||||
modify: func(c *Config) {
|
||||
c.Logging.Level = "verbose"
|
||||
},
|
||||
expectError: true,
|
||||
errorSubstr: "logging.level must be one of",
|
||||
},
|
||||
{
|
||||
name: "invalid_log_format",
|
||||
modify: func(c *Config) {
|
||||
c.Logging.Format = "xml"
|
||||
},
|
||||
expectError: true,
|
||||
errorSubstr: "logging.format must be one of",
|
||||
},
|
||||
{
|
||||
name: "invalid_bcrypt_cost_too_low",
|
||||
modify: func(c *Config) {
|
||||
c.Auth.BcryptCost = 3
|
||||
},
|
||||
expectError: true,
|
||||
errorSubstr: "bcrypt_cost must be between",
|
||||
},
|
||||
{
|
||||
name: "invalid_bcrypt_cost_too_high",
|
||||
modify: func(c *Config) {
|
||||
c.Auth.BcryptCost = 32
|
||||
},
|
||||
expectError: true,
|
||||
errorSubstr: "bcrypt_cost must be between",
|
||||
},
|
||||
{
|
||||
name: "valid_s3_backend",
|
||||
modify: func(c *Config) {
|
||||
c.Storage.Backend = "s3"
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid_postgresql_backend",
|
||||
modify: func(c *Config) {
|
||||
c.Metadata.Backend = "postgresql"
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
cfg := Default()
|
||||
tt.modify(cfg)
|
||||
err := cfg.Validate()
|
||||
|
||||
if tt.expectError {
|
||||
s.Error(err)
|
||||
if tt.errorSubstr != "" {
|
||||
s.Contains(err.Error(), tt.errorSubstr)
|
||||
}
|
||||
} else {
|
||||
s.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ConfigTestSuite) TestLoad() {
|
||||
tests := []struct {
|
||||
name string
|
||||
configYAML string
|
||||
envVars map[string]string
|
||||
expectError bool
|
||||
validate func(*Config)
|
||||
}{
|
||||
{
|
||||
name: "valid_yaml_config",
|
||||
configYAML: `
|
||||
server:
|
||||
host: 127.0.0.1
|
||||
port: 9000
|
||||
storage:
|
||||
backend: filesystem
|
||||
path: /custom/path
|
||||
logging:
|
||||
level: debug
|
||||
format: pretty
|
||||
`,
|
||||
expectError: false,
|
||||
validate: func(cfg *Config) {
|
||||
s.Equal("127.0.0.1", cfg.Server.Host)
|
||||
s.Equal(9000, cfg.Server.Port)
|
||||
s.Equal("/custom/path", cfg.Storage.Path)
|
||||
s.Equal("debug", cfg.Logging.Level)
|
||||
s.Equal("pretty", cfg.Logging.Format)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "env_var_override",
|
||||
configYAML: `
|
||||
server:
|
||||
port: 8080
|
||||
`,
|
||||
envVars: map[string]string{
|
||||
"GOHOARDER_SERVER_PORT": "9090",
|
||||
},
|
||||
expectError: false,
|
||||
validate: func(cfg *Config) {
|
||||
s.Equal(9090, cfg.Server.Port)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid_yaml",
|
||||
configYAML: `
|
||||
server: [invalid
|
||||
`,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "validation_failure",
|
||||
configYAML: `
|
||||
server:
|
||||
port: 100000
|
||||
`,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "complete_config",
|
||||
configYAML: `
|
||||
server:
|
||||
host: 0.0.0.0
|
||||
port: 8080
|
||||
read_timeout: 300s
|
||||
write_timeout: 300s
|
||||
storage:
|
||||
backend: s3
|
||||
s3:
|
||||
endpoint: s3.amazonaws.com
|
||||
region: us-east-1
|
||||
bucket: my-cache
|
||||
access_key_id: AKIAIOSFODNN7EXAMPLE
|
||||
secret_access_key: wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY
|
||||
metadata:
|
||||
backend: postgresql
|
||||
postgresql:
|
||||
host: localhost
|
||||
port: 5432
|
||||
database: gohoarder
|
||||
user: postgres
|
||||
password: secret
|
||||
ssl_mode: require
|
||||
cache:
|
||||
default_ttl: 168h
|
||||
max_size_bytes: 536870912000
|
||||
security:
|
||||
enabled: true
|
||||
block_on_severity: high
|
||||
scanners:
|
||||
trivy:
|
||||
enabled: true
|
||||
timeout: 300s
|
||||
auth:
|
||||
enabled: true
|
||||
bcrypt_cost: 12
|
||||
`,
|
||||
expectError: false,
|
||||
validate: func(cfg *Config) {
|
||||
s.Equal("s3", cfg.Storage.Backend)
|
||||
s.Equal("s3.amazonaws.com", cfg.Storage.S3.Endpoint)
|
||||
s.Equal("postgresql", cfg.Metadata.Backend)
|
||||
s.Equal("localhost", cfg.Metadata.PostgreSQL.Host)
|
||||
s.True(cfg.Security.Enabled)
|
||||
s.Equal(12, cfg.Auth.BcryptCost)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
// Write config file
|
||||
configPath := filepath.Join(s.tempDir, "config.yaml")
|
||||
err := os.WriteFile(configPath, []byte(tt.configYAML), 0644)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Set environment variables
|
||||
for k, v := range tt.envVars {
|
||||
os.Setenv(k, v)
|
||||
defer os.Unsetenv(k)
|
||||
}
|
||||
|
||||
// Load config
|
||||
cfg, err := Load(configPath)
|
||||
|
||||
if tt.expectError {
|
||||
s.Error(err)
|
||||
} else {
|
||||
s.NoError(err)
|
||||
s.NotNil(cfg)
|
||||
if tt.validate != nil {
|
||||
tt.validate(cfg)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ConfigTestSuite) TestLoadMissingFile() {
|
||||
// Should return error when file explicitly specified but not found
|
||||
cfg, err := Load("/nonexistent/path/to/config.yaml")
|
||||
s.Error(err)
|
||||
s.Nil(cfg)
|
||||
}
|
||||
|
||||
func (s *ConfigTestSuite) TestLoadWithDefaults() {
|
||||
// Invalid config path should return defaults
|
||||
cfg := LoadWithDefaults("/invalid/path/config.yaml")
|
||||
s.NotNil(cfg)
|
||||
s.Equal(8080, cfg.Server.Port)
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkDefault(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = Default()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkValidate(b *testing.B) {
|
||||
cfg := Default()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = cfg.Validate()
|
||||
}
|
||||
}
|
||||
|
||||
// Table-driven edge cases
|
||||
func TestConfigEdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *Config
|
||||
valid bool
|
||||
}{
|
||||
{
|
||||
name: "minimal_config",
|
||||
config: &Config{Server: ServerConfig{Port: 8080}, Storage: StorageConfig{Backend: "filesystem"}, Metadata: MetadataConfig{Backend: "sqlite"}, Logging: LoggingConfig{Level: "info", Format: "json"}, Security: SecurityConfig{BlockOnSeverity: "high"}, Auth: AuthConfig{BcryptCost: 10}},
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "zero_ttl",
|
||||
config: func() *Config { c := Default(); c.Cache.DefaultTTL = 0; return c }(),
|
||||
valid: true, // Zero is valid (no caching)
|
||||
},
|
||||
{
|
||||
name: "max_bcrypt_cost",
|
||||
config: func() *Config { c := Default(); c.Auth.BcryptCost = 31; return c }(),
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "min_bcrypt_cost",
|
||||
config: func() *Config { c := Default(); c.Auth.BcryptCost = 4; return c }(),
|
||||
valid: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.Validate()
|
||||
if tt.valid {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// Load loads configuration from file and environment variables
|
||||
func Load(configPath string) (*Config, error) {
|
||||
v := viper.New()
|
||||
|
||||
// Set config file if provided
|
||||
if configPath != "" {
|
||||
v.SetConfigFile(configPath)
|
||||
} else {
|
||||
// Look for config.yaml in current directory and /etc/gohoarder
|
||||
v.SetConfigName("config")
|
||||
v.SetConfigType("yaml")
|
||||
v.AddConfigPath(".")
|
||||
v.AddConfigPath("/etc/gohoarder")
|
||||
v.AddConfigPath("$HOME/.gohoarder")
|
||||
}
|
||||
|
||||
// Set environment variable prefix
|
||||
v.SetEnvPrefix("GOHOARDER")
|
||||
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||
v.AutomaticEnv()
|
||||
|
||||
// Read config file
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
// If no config file found, use defaults
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
|
||||
return nil, fmt.Errorf("failed to read config file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Start with defaults
|
||||
cfg := Default()
|
||||
|
||||
// Unmarshal into config struct
|
||||
if err := v.Unmarshal(cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
|
||||
// Validate configuration
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("config validation failed: %w", err)
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// LoadWithDefaults loads configuration or returns defaults on error
|
||||
func LoadWithDefaults(configPath string) *Config {
|
||||
cfg, err := Load(configPath)
|
||||
if err != nil {
|
||||
return Default()
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package errors
|
||||
|
||||
// Error codes following consistent naming convention
|
||||
const (
|
||||
// Client errors (4xx)
|
||||
ErrCodeBadRequest = "BAD_REQUEST"
|
||||
ErrCodeUnauthorized = "UNAUTHORIZED"
|
||||
ErrCodeForbidden = "FORBIDDEN"
|
||||
ErrCodeNotFound = "NOT_FOUND"
|
||||
ErrCodeRateLimited = "RATE_LIMITED"
|
||||
ErrCodePayloadTooLarge = "PAYLOAD_TOO_LARGE"
|
||||
ErrCodeInvalidAPIKey = "INVALID_API_KEY" // #nosec G101 -- Not a credential, just an error code constant
|
||||
ErrCodeQuotaExceeded = "QUOTA_EXCEEDED"
|
||||
ErrCodeConflict = "CONFLICT"
|
||||
ErrCodeInvalidConfig = "INVALID_CONFIG"
|
||||
|
||||
// Package-specific errors
|
||||
ErrCodePackageNotFound = "PACKAGE_NOT_FOUND"
|
||||
ErrCodeVersionNotFound = "VERSION_NOT_FOUND"
|
||||
ErrCodeChecksumMismatch = "CHECKSUM_MISMATCH"
|
||||
ErrCodeCorruptPackage = "CORRUPT_PACKAGE"
|
||||
ErrCodeSecurityBlocked = "SECURITY_BLOCKED"
|
||||
ErrCodeSecurityViolation = "SECURITY_VIOLATION" // Package has vulnerabilities exceeding thresholds
|
||||
ErrCodeUpstreamError = "UPSTREAM_ERROR"
|
||||
|
||||
// Server errors (5xx)
|
||||
ErrCodeInternalServer = "INTERNAL_SERVER_ERROR"
|
||||
ErrCodeStorageFailure = "STORAGE_FAILURE"
|
||||
ErrCodeUpstreamFailure = "UPSTREAM_FAILURE"
|
||||
ErrCodeDatabaseFailure = "DATABASE_FAILURE"
|
||||
ErrCodeServiceUnavailable = "SERVICE_UNAVAILABLE"
|
||||
ErrCodeCircuitOpen = "CIRCUIT_OPEN"
|
||||
)
|
||||
|
||||
// HTTPStatusCode maps error codes to HTTP status codes
|
||||
var HTTPStatusCode = map[string]int{
|
||||
ErrCodeBadRequest: 400,
|
||||
ErrCodeUnauthorized: 401,
|
||||
ErrCodeForbidden: 403,
|
||||
ErrCodeNotFound: 404,
|
||||
ErrCodeConflict: 409,
|
||||
ErrCodeRateLimited: 429,
|
||||
ErrCodePayloadTooLarge: 413,
|
||||
ErrCodeInvalidAPIKey: 401,
|
||||
ErrCodeQuotaExceeded: 429,
|
||||
ErrCodeInvalidConfig: 400,
|
||||
ErrCodePackageNotFound: 404,
|
||||
ErrCodeVersionNotFound: 404,
|
||||
ErrCodeChecksumMismatch: 422,
|
||||
ErrCodeCorruptPackage: 422,
|
||||
ErrCodeSecurityBlocked: 403,
|
||||
ErrCodeSecurityViolation: 426, // Upgrade Required
|
||||
ErrCodeUpstreamError: 502,
|
||||
ErrCodeInternalServer: 500,
|
||||
ErrCodeStorageFailure: 500,
|
||||
ErrCodeUpstreamFailure: 502,
|
||||
ErrCodeDatabaseFailure: 500,
|
||||
ErrCodeServiceUnavailable: 503,
|
||||
ErrCodeCircuitOpen: 503,
|
||||
}
|
||||
|
||||
// GetHTTPStatus returns the HTTP status code for an error code
|
||||
func GetHTTPStatus(code string) int {
|
||||
if status, ok := HTTPStatusCode[code]; ok {
|
||||
return status
|
||||
}
|
||||
return 500 // Default to internal server error
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Error represents a structured error with code and details
|
||||
type Error struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Details interface{} `json:"details,omitempty"`
|
||||
Trace []string `json:"trace,omitempty"`
|
||||
Cause error `json:"-"` // Internal cause, not serialized
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (e *Error) Error() string {
|
||||
if e.Cause != nil {
|
||||
return fmt.Sprintf("%s: %s (caused by: %v)", e.Code, e.Message, e.Cause)
|
||||
}
|
||||
return fmt.Sprintf("%s: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// Unwrap returns the cause for errors.Is/As support
|
||||
func (e *Error) Unwrap() error {
|
||||
return e.Cause
|
||||
}
|
||||
|
||||
// New creates a new error with the given code and message
|
||||
func New(code, message string) *Error {
|
||||
return &Error{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
// Newf creates a new error with formatted message
|
||||
func Newf(code, format string, args ...interface{}) *Error {
|
||||
return &Error{
|
||||
Code: code,
|
||||
Message: fmt.Sprintf(format, args...),
|
||||
}
|
||||
}
|
||||
|
||||
// WithDetails adds details to the error
|
||||
func (e *Error) WithDetails(details interface{}) *Error {
|
||||
e.Details = details
|
||||
return e
|
||||
}
|
||||
|
||||
// WithTrace adds stack trace to the error
|
||||
func (e *Error) WithTrace(trace []string) *Error {
|
||||
e.Trace = trace
|
||||
return e
|
||||
}
|
||||
|
||||
// WithCause adds an underlying cause to the error
|
||||
func (e *Error) WithCause(cause error) *Error {
|
||||
e.Cause = cause
|
||||
return e
|
||||
}
|
||||
|
||||
// Wrap wraps an existing error with a new code and message
|
||||
func Wrap(err error, code, message string) *Error {
|
||||
return &Error{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Wrapf wraps an existing error with formatted message
|
||||
func Wrapf(err error, code, format string, args ...interface{}) *Error {
|
||||
return &Error{
|
||||
Code: code,
|
||||
Message: fmt.Sprintf(format, args...),
|
||||
Cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Common error constructors
|
||||
func BadRequest(message string) *Error {
|
||||
return New(ErrCodeBadRequest, message)
|
||||
}
|
||||
|
||||
func Unauthorized(message string) *Error {
|
||||
return New(ErrCodeUnauthorized, message)
|
||||
}
|
||||
|
||||
func Forbidden(message string) *Error {
|
||||
return New(ErrCodeForbidden, message)
|
||||
}
|
||||
|
||||
func NotFound(message string) *Error {
|
||||
return New(ErrCodeNotFound, message)
|
||||
}
|
||||
|
||||
func InternalServer(message string) *Error {
|
||||
return New(ErrCodeInternalServer, message)
|
||||
}
|
||||
|
||||
func PackageNotFound(name, version string) *Error {
|
||||
return New(ErrCodePackageNotFound, fmt.Sprintf("Package %s@%s not found", name, version)).
|
||||
WithDetails(map[string]string{
|
||||
"package": name,
|
||||
"version": version,
|
||||
})
|
||||
}
|
||||
|
||||
func QuotaExceeded(limit int64) *Error {
|
||||
return New(ErrCodeQuotaExceeded, "Storage quota exceeded").
|
||||
WithDetails(map[string]interface{}{
|
||||
"limit_bytes": limit,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,305 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ErrorsTestSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func TestErrorsTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(ErrorsTestSuite))
|
||||
}
|
||||
|
||||
func (s *ErrorsTestSuite) TestNew() {
|
||||
tests := []struct {
|
||||
name string
|
||||
code string
|
||||
message string
|
||||
}{
|
||||
{
|
||||
name: "simple_error",
|
||||
code: ErrCodeNotFound,
|
||||
message: "Resource not found",
|
||||
},
|
||||
{
|
||||
name: "empty_message",
|
||||
code: ErrCodeBadRequest,
|
||||
message: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
err := New(tt.code, tt.message)
|
||||
s.Equal(tt.code, err.Code)
|
||||
s.Equal(tt.message, err.Message)
|
||||
s.Nil(err.Details)
|
||||
s.Nil(err.Trace)
|
||||
s.Nil(err.Cause)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ErrorsTestSuite) TestNewf() {
|
||||
tests := []struct {
|
||||
name string
|
||||
code string
|
||||
format string
|
||||
args []interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "formatted_message",
|
||||
code: ErrCodePackageNotFound,
|
||||
format: "Package %s@%s not found",
|
||||
args: []interface{}{"react", "18.2.0"},
|
||||
expected: "Package react@18.2.0 not found",
|
||||
},
|
||||
{
|
||||
name: "no_args",
|
||||
code: ErrCodeInternalServer,
|
||||
format: "Internal error",
|
||||
args: []interface{}{},
|
||||
expected: "Internal error",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
err := Newf(tt.code, tt.format, tt.args...)
|
||||
s.Equal(tt.code, err.Code)
|
||||
s.Equal(tt.expected, err.Message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ErrorsTestSuite) TestWithDetails() {
|
||||
tests := []struct {
|
||||
name string
|
||||
details interface{}
|
||||
}{
|
||||
{
|
||||
name: "map_details",
|
||||
details: map[string]string{"key": "value"},
|
||||
},
|
||||
{
|
||||
name: "string_details",
|
||||
details: "some details",
|
||||
},
|
||||
{
|
||||
name: "nil_details",
|
||||
details: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
err := New(ErrCodeBadRequest, "test").WithDetails(tt.details)
|
||||
s.Equal(tt.details, err.Details)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ErrorsTestSuite) TestWithTrace() {
|
||||
trace := []string{"file1.go:10", "file2.go:20"}
|
||||
err := New(ErrCodeInternalServer, "test").WithTrace(trace)
|
||||
s.Equal(trace, err.Trace)
|
||||
}
|
||||
|
||||
func (s *ErrorsTestSuite) TestWithCause() {
|
||||
cause := errors.New("underlying error")
|
||||
err := New(ErrCodeStorageFailure, "test").WithCause(cause)
|
||||
s.Equal(cause, err.Cause)
|
||||
s.Contains(err.Error(), "underlying error")
|
||||
}
|
||||
|
||||
func (s *ErrorsTestSuite) TestWrap() {
|
||||
cause := errors.New("original error")
|
||||
wrapped := Wrap(cause, ErrCodeDatabaseFailure, "database connection failed")
|
||||
|
||||
s.Equal(ErrCodeDatabaseFailure, wrapped.Code)
|
||||
s.Equal("database connection failed", wrapped.Message)
|
||||
s.Equal(cause, wrapped.Cause)
|
||||
s.True(errors.Is(wrapped, cause))
|
||||
}
|
||||
|
||||
func (s *ErrorsTestSuite) TestWrapf() {
|
||||
cause := errors.New("connection refused")
|
||||
wrapped := Wrapf(cause, ErrCodeUpstreamFailure, "failed to connect to %s", "registry.npmjs.org")
|
||||
|
||||
s.Equal(ErrCodeUpstreamFailure, wrapped.Code)
|
||||
s.Equal("failed to connect to registry.npmjs.org", wrapped.Message)
|
||||
s.Equal(cause, wrapped.Cause)
|
||||
}
|
||||
|
||||
func (s *ErrorsTestSuite) TestErrorString() {
|
||||
tests := []struct {
|
||||
name string
|
||||
err *Error
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "error_without_cause",
|
||||
err: New(ErrCodeNotFound, "not found"),
|
||||
expected: "NOT_FOUND: not found",
|
||||
},
|
||||
{
|
||||
name: "error_with_cause",
|
||||
err: Wrap(errors.New("io error"), ErrCodeStorageFailure, "storage failed"),
|
||||
expected: "STORAGE_FAILURE: storage failed (caused by: io error)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
s.Equal(tt.expected, tt.err.Error())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ErrorsTestSuite) TestCommonConstructors() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func() *Error
|
||||
wantCode string
|
||||
}{
|
||||
{
|
||||
name: "bad_request",
|
||||
fn: func() *Error { return BadRequest("invalid input") },
|
||||
wantCode: ErrCodeBadRequest,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
fn: func() *Error { return Unauthorized("invalid token") },
|
||||
wantCode: ErrCodeUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "forbidden",
|
||||
fn: func() *Error { return Forbidden("access denied") },
|
||||
wantCode: ErrCodeForbidden,
|
||||
},
|
||||
{
|
||||
name: "not_found",
|
||||
fn: func() *Error { return NotFound("resource missing") },
|
||||
wantCode: ErrCodeNotFound,
|
||||
},
|
||||
{
|
||||
name: "internal_server",
|
||||
fn: func() *Error { return InternalServer("server error") },
|
||||
wantCode: ErrCodeInternalServer,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
err := tt.fn()
|
||||
s.Equal(tt.wantCode, err.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ErrorsTestSuite) TestPackageNotFound() {
|
||||
err := PackageNotFound("lodash", "4.17.21")
|
||||
s.Equal(ErrCodePackageNotFound, err.Code)
|
||||
s.Equal("Package lodash@4.17.21 not found", err.Message)
|
||||
s.NotNil(err.Details)
|
||||
|
||||
details, ok := err.Details.(map[string]string)
|
||||
s.True(ok)
|
||||
s.Equal("lodash", details["package"])
|
||||
s.Equal("4.17.21", details["version"])
|
||||
}
|
||||
|
||||
func (s *ErrorsTestSuite) TestQuotaExceeded() {
|
||||
limit := int64(1000000)
|
||||
err := QuotaExceeded(limit)
|
||||
s.Equal(ErrCodeQuotaExceeded, err.Code)
|
||||
s.NotNil(err.Details)
|
||||
|
||||
details, ok := err.Details.(map[string]interface{})
|
||||
s.True(ok)
|
||||
s.Equal(limit, details["limit_bytes"])
|
||||
}
|
||||
|
||||
func (s *ErrorsTestSuite) TestUnwrap() {
|
||||
cause := errors.New("root cause")
|
||||
wrapped := Wrap(cause, ErrCodeDatabaseFailure, "db error")
|
||||
|
||||
unwrapped := wrapped.Unwrap()
|
||||
s.Equal(cause, unwrapped)
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkNewError(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = New(ErrCodeNotFound, "test error")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNewErrorWithDetails(b *testing.B) {
|
||||
details := map[string]string{"key": "value"}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = New(ErrCodeNotFound, "test error").WithDetails(details)
|
||||
}
|
||||
}
|
||||
|
||||
// Test edge cases
|
||||
func (s *ErrorsTestSuite) TestEdgeCases() {
|
||||
s.Run("nil_error_wrap", func() {
|
||||
wrapped := Wrap(nil, ErrCodeInternalServer, "test")
|
||||
s.Nil(wrapped.Cause)
|
||||
})
|
||||
|
||||
s.Run("chained_wrapping", func() {
|
||||
err1 := errors.New("base")
|
||||
err2 := Wrap(err1, ErrCodeStorageFailure, "storage")
|
||||
err3 := Wrap(err2, ErrCodeInternalServer, "internal")
|
||||
|
||||
s.True(errors.Is(err3, err2))
|
||||
s.True(errors.Is(err3, err1))
|
||||
})
|
||||
|
||||
s.Run("large_details", func() {
|
||||
largeDetails := make(map[string]string)
|
||||
for i := 0; i < 1000; i++ {
|
||||
largeDetails[string(rune(i))] = "value"
|
||||
}
|
||||
err := New(ErrCodeBadRequest, "test").WithDetails(largeDetails)
|
||||
s.Equal(largeDetails, err.Details)
|
||||
})
|
||||
}
|
||||
|
||||
// Table-driven test for error codes
|
||||
func TestGetHTTPStatus(t *testing.T) {
|
||||
tests := []struct {
|
||||
code string
|
||||
expectedStatus int
|
||||
}{
|
||||
{ErrCodeBadRequest, 400},
|
||||
{ErrCodeUnauthorized, 401},
|
||||
{ErrCodeForbidden, 403},
|
||||
{ErrCodeNotFound, 404},
|
||||
{ErrCodeConflict, 409},
|
||||
{ErrCodePayloadTooLarge, 413},
|
||||
{ErrCodeChecksumMismatch, 422},
|
||||
{ErrCodeRateLimited, 429},
|
||||
{ErrCodeInternalServer, 500},
|
||||
{ErrCodeDatabaseFailure, 500},
|
||||
{ErrCodeUpstreamFailure, 502},
|
||||
{ErrCodeServiceUnavailable, 503},
|
||||
{"UNKNOWN_CODE", 500}, // Default
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.code, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expectedStatus, GetHTTPStatus(tt.code))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
)
|
||||
|
||||
// Response is the standard API response envelope
|
||||
type Response struct {
|
||||
Success bool `json:"success"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
Error *ErrorResponse `json:"error,omitempty"`
|
||||
Metadata *ResponseMeta `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// ErrorResponse contains error details
|
||||
type ErrorResponse struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Details interface{} `json:"details,omitempty"`
|
||||
Trace []string `json:"trace,omitempty"`
|
||||
}
|
||||
|
||||
// ResponseMeta contains request metadata
|
||||
type ResponseMeta struct {
|
||||
RequestID string `json:"request_id"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
Duration string `json:"duration,omitempty"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// WriteJSON writes a success response as JSON
|
||||
func WriteJSON(w http.ResponseWriter, statusCode int, data interface{}, meta *ResponseMeta) {
|
||||
response := Response{
|
||||
Success: statusCode < 400,
|
||||
Data: data,
|
||||
Metadata: meta,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
// Fallback to simple error response
|
||||
http.Error(w, `{"success":false,"error":{"code":"ENCODING_ERROR","message":"Failed to encode response"}}`, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// WriteError writes an error response as JSON
|
||||
func WriteError(w http.ResponseWriter, statusCode int, err *Error, meta *ResponseMeta) {
|
||||
errResp := &ErrorResponse{
|
||||
Code: err.Code,
|
||||
Message: err.Message,
|
||||
Details: err.Details,
|
||||
Trace: err.Trace,
|
||||
}
|
||||
|
||||
response := Response{
|
||||
Success: false,
|
||||
Error: errResp,
|
||||
Metadata: meta,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
if encErr := json.NewEncoder(w).Encode(response); encErr != nil {
|
||||
// Fallback to simple error response
|
||||
http.Error(w, `{"success":false,"error":{"code":"ENCODING_ERROR","message":"Failed to encode error response"}}`, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// WriteErrorSimple writes an error without metadata
|
||||
func WriteErrorSimple(w http.ResponseWriter, err *Error) {
|
||||
statusCode := GetHTTPStatus(err.Code)
|
||||
meta := &ResponseMeta{
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
WriteError(w, statusCode, err, meta)
|
||||
}
|
||||
|
||||
// WriteJSONSimple writes a success response without metadata
|
||||
func WriteJSONSimple(w http.ResponseWriter, statusCode int, data interface{}) {
|
||||
meta := &ResponseMeta{
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
WriteJSON(w, statusCode, data, meta)
|
||||
}
|
||||
@@ -0,0 +1,178 @@
|
||||
package health
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
"github.com/lukaszraczylo/gohoarder/internal/version"
|
||||
)
|
||||
|
||||
// Status represents component health status
|
||||
type Status string
|
||||
|
||||
const (
|
||||
StatusHealthy Status = "healthy"
|
||||
StatusUnhealthy Status = "unhealthy"
|
||||
StatusDegraded Status = "degraded"
|
||||
)
|
||||
|
||||
// Check represents a single health check
|
||||
type Check struct {
|
||||
Name string `json:"name"`
|
||||
Status Status `json:"status"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Fn func(context.Context) (Status, string) `json:"-"`
|
||||
}
|
||||
|
||||
// Response is the health check response
|
||||
type Response struct {
|
||||
Success bool `json:"success"`
|
||||
Data *HealthData `json:"data,omitempty"`
|
||||
Metadata *Metadata `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// HealthData contains health check data
|
||||
type HealthData struct {
|
||||
Status Status `json:"status"`
|
||||
Version string `json:"version"`
|
||||
Uptime string `json:"uptime"`
|
||||
Components map[string]*Component `json:"components"`
|
||||
}
|
||||
|
||||
// Component represents a system component
|
||||
type Component struct {
|
||||
Status Status `json:"status"`
|
||||
Details map[string]interface{} `json:"details,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// Metadata contains response metadata
|
||||
type Metadata struct {
|
||||
RequestID string `json:"request_id"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
}
|
||||
|
||||
// Checker manages health checks
|
||||
type Checker struct {
|
||||
checks []*Check
|
||||
startTime time.Time
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// New creates a new health checker
|
||||
func New() *Checker {
|
||||
return &Checker{
|
||||
checks: make([]*Check, 0),
|
||||
startTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// AddCheck adds a health check
|
||||
func (c *Checker) AddCheck(name string, fn func(context.Context) (Status, string)) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.checks = append(c.checks, &Check{
|
||||
Name: name,
|
||||
Fn: fn,
|
||||
})
|
||||
}
|
||||
|
||||
// RunChecks runs all health checks
|
||||
func (c *Checker) RunChecks(ctx context.Context) *HealthData {
|
||||
c.mu.RLock()
|
||||
checks := make([]*Check, len(c.checks))
|
||||
copy(checks, c.checks)
|
||||
c.mu.RUnlock()
|
||||
|
||||
components := make(map[string]*Component)
|
||||
overallStatus := StatusHealthy
|
||||
|
||||
for _, check := range checks {
|
||||
status, errMsg := check.Fn(ctx)
|
||||
components[check.Name] = &Component{
|
||||
Status: status,
|
||||
Error: errMsg,
|
||||
}
|
||||
|
||||
// Determine overall status
|
||||
if status == StatusUnhealthy {
|
||||
overallStatus = StatusUnhealthy
|
||||
} else if status == StatusDegraded && overallStatus == StatusHealthy {
|
||||
overallStatus = StatusDegraded
|
||||
}
|
||||
}
|
||||
|
||||
return &HealthData{
|
||||
Status: overallStatus,
|
||||
Version: version.Version,
|
||||
Uptime: time.Since(c.startTime).String(),
|
||||
Components: components,
|
||||
}
|
||||
}
|
||||
|
||||
// HealthHandler returns an HTTP handler for health checks
|
||||
func (c *Checker) HealthHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
healthData := c.RunChecks(ctx)
|
||||
|
||||
response := Response{
|
||||
Success: healthData.Status == StatusHealthy,
|
||||
Data: healthData,
|
||||
Metadata: &Metadata{
|
||||
RequestID: r.Header.Get("X-Request-ID"),
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
|
||||
statusCode := http.StatusOK
|
||||
if healthData.Status == StatusUnhealthy {
|
||||
statusCode = http.StatusServiceUnavailable
|
||||
} else if healthData.Status == StatusDegraded {
|
||||
statusCode = http.StatusOK // 200 but degraded
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
_ = json.NewEncoder(w).Encode(response) // #nosec G104 -- JSON response write
|
||||
}
|
||||
}
|
||||
|
||||
// ReadyHandler returns an HTTP handler for readiness checks
|
||||
func (c *Checker) ReadyHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
healthData := c.RunChecks(ctx)
|
||||
|
||||
ready := healthData.Status != StatusUnhealthy
|
||||
|
||||
response := Response{
|
||||
Success: ready,
|
||||
Data: &HealthData{
|
||||
Status: healthData.Status,
|
||||
Components: healthData.Components,
|
||||
},
|
||||
Metadata: &Metadata{
|
||||
RequestID: r.Header.Get("X-Request-ID"),
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
|
||||
statusCode := http.StatusOK
|
||||
if !ready {
|
||||
statusCode = http.StatusServiceUnavailable
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
_ = json.NewEncoder(w).Encode(response) // #nosec G104 -- JSON response write
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,275 @@
|
||||
package lock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrLockNotAcquired = errors.New("lock not acquired")
|
||||
ErrLockNotHeld = errors.New("lock not held by this instance")
|
||||
ErrInvalidTTL = errors.New("invalid TTL: must be positive")
|
||||
)
|
||||
|
||||
// Lock represents a distributed lock
|
||||
type Lock struct {
|
||||
client *redis.Client
|
||||
key string
|
||||
value string
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// Manager manages distributed locks using Redis
|
||||
type Manager struct {
|
||||
client *redis.Client
|
||||
}
|
||||
|
||||
// Config holds Redis connection configuration
|
||||
type Config struct {
|
||||
Addr string
|
||||
Password string
|
||||
DB int
|
||||
}
|
||||
|
||||
// NewManager creates a new lock manager
|
||||
func NewManager(cfg Config) (*Manager, error) {
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: cfg.Addr,
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
})
|
||||
|
||||
// Test connection
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("addr", cfg.Addr).
|
||||
Int("db", cfg.DB).
|
||||
Msg("Connected to Redis for distributed locking")
|
||||
|
||||
return &Manager{
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Acquire attempts to acquire a lock with the given key and TTL
|
||||
// Returns a Lock instance if successful, or an error if the lock is already held
|
||||
func (m *Manager) Acquire(ctx context.Context, key string, ttl time.Duration) (*Lock, error) {
|
||||
if ttl <= 0 {
|
||||
return nil, ErrInvalidTTL
|
||||
}
|
||||
|
||||
// Generate unique value for this lock instance
|
||||
value, err := generateLockValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Try to acquire lock using SET NX (set if not exists)
|
||||
success, err := m.client.SetNX(ctx, key, value, ttl).Result()
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("key", key).
|
||||
Msg("Failed to acquire lock")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !success {
|
||||
log.Debug().
|
||||
Str("key", key).
|
||||
Msg("Lock already held by another instance")
|
||||
return nil, ErrLockNotAcquired
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("key", key).
|
||||
Dur("ttl", ttl).
|
||||
Msg("Lock acquired successfully")
|
||||
|
||||
return &Lock{
|
||||
client: m.client,
|
||||
key: key,
|
||||
value: value,
|
||||
ttl: ttl,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TryAcquire attempts to acquire a lock, retrying for the specified duration
|
||||
// Returns a Lock instance if successful within the timeout, or an error
|
||||
func (m *Manager) TryAcquire(ctx context.Context, key string, ttl, timeout time.Duration) (*Lock, error) {
|
||||
if ttl <= 0 {
|
||||
return nil, ErrInvalidTTL
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(timeout)
|
||||
ticker := time.NewTicker(50 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
lock, err := m.Acquire(ctx, key, ttl)
|
||||
if err == nil {
|
||||
return lock, nil
|
||||
}
|
||||
|
||||
if err != ErrLockNotAcquired {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-ticker.C:
|
||||
if time.Now().After(deadline) {
|
||||
return nil, ErrLockNotAcquired
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Release releases the lock
|
||||
// Returns an error if the lock is not held by this instance
|
||||
func (l *Lock) Release(ctx context.Context) error {
|
||||
// Use Lua script to ensure atomic check-and-delete
|
||||
// Only delete if the value matches (ensures we own the lock)
|
||||
script := redis.NewScript(`
|
||||
if redis.call("get", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("del", KEYS[1])
|
||||
else
|
||||
return 0
|
||||
end
|
||||
`)
|
||||
|
||||
result, err := script.Run(ctx, l.client, []string{l.key}, l.value).Result()
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("key", l.key).
|
||||
Msg("Failed to release lock")
|
||||
return err
|
||||
}
|
||||
|
||||
// Result of 0 means the lock was not deleted (not owned by us)
|
||||
if result.(int64) == 0 {
|
||||
log.Warn().
|
||||
Str("key", l.key).
|
||||
Msg("Attempted to release lock not held by this instance")
|
||||
return ErrLockNotHeld
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("key", l.key).
|
||||
Msg("Lock released successfully")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extend extends the lock TTL
|
||||
// Returns an error if the lock is not held by this instance
|
||||
func (l *Lock) Extend(ctx context.Context, additionalTTL time.Duration) error {
|
||||
// Use Lua script to ensure atomic check-and-extend
|
||||
script := redis.NewScript(`
|
||||
if redis.call("get", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("expire", KEYS[1], ARGV[2])
|
||||
else
|
||||
return 0
|
||||
end
|
||||
`)
|
||||
|
||||
newTTL := l.ttl + additionalTTL
|
||||
result, err := script.Run(ctx, l.client, []string{l.key}, l.value, int(newTTL.Seconds())).Result()
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("key", l.key).
|
||||
Msg("Failed to extend lock")
|
||||
return err
|
||||
}
|
||||
|
||||
if result.(int64) == 0 {
|
||||
log.Warn().
|
||||
Str("key", l.key).
|
||||
Msg("Attempted to extend lock not held by this instance")
|
||||
return ErrLockNotHeld
|
||||
}
|
||||
|
||||
l.ttl = newTTL
|
||||
log.Debug().
|
||||
Str("key", l.key).
|
||||
Dur("new_ttl", newTTL).
|
||||
Msg("Lock TTL extended")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsHeld checks if the lock is still held by this instance
|
||||
func (l *Lock) IsHeld(ctx context.Context) bool {
|
||||
value, err := l.client.Get(ctx, l.key).Result()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return value == l.value
|
||||
}
|
||||
|
||||
// Close closes the lock manager and its Redis connection
|
||||
func (m *Manager) Close() error {
|
||||
return m.client.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
}
|
||||
|
||||
// generateLockValue generates a cryptographically random lock value
|
||||
func generateLockValue() (string, error) {
|
||||
bytes := make([]byte, 16)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// WithLock executes a function while holding a distributed lock
|
||||
// The lock is automatically released when the function returns
|
||||
func (m *Manager) WithLock(ctx context.Context, key string, ttl time.Duration, fn func(context.Context) error) error {
|
||||
lock, err := m.Acquire(ctx, key, ttl)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := lock.Release(context.Background()); err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("key", key).
|
||||
Msg("Failed to release lock in defer")
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(ctx)
|
||||
}
|
||||
|
||||
// WithRetryLock executes a function while holding a distributed lock
|
||||
// It retries acquisition for the specified timeout duration
|
||||
func (m *Manager) WithRetryLock(ctx context.Context, key string, ttl, timeout time.Duration, fn func(context.Context) error) error {
|
||||
lock, err := m.TryAcquire(ctx, key, ttl, timeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := lock.Release(context.Background()); err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("key", key).
|
||||
Msg("Failed to release lock in defer")
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(ctx)
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Config contains logger configuration
|
||||
type Config struct {
|
||||
Level string // debug, info, warn, error
|
||||
Format string // json, pretty
|
||||
}
|
||||
|
||||
// Init initializes the global logger
|
||||
func Init(cfg Config) error {
|
||||
// Set log level
|
||||
level, err := zerolog.ParseLevel(cfg.Level)
|
||||
if err != nil {
|
||||
level = zerolog.InfoLevel
|
||||
}
|
||||
zerolog.SetGlobalLevel(level)
|
||||
|
||||
// Set time format
|
||||
zerolog.TimeFieldFormat = zerolog.TimeFormatUnixMs
|
||||
|
||||
// Set format
|
||||
if cfg.Format == "pretty" {
|
||||
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: "15:04:05.000"})
|
||||
} else {
|
||||
// JSON format (default for production)
|
||||
log.Logger = zerolog.New(os.Stdout).With().Timestamp().Logger()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns the global logger
|
||||
func Get() *zerolog.Logger {
|
||||
return &log.Logger
|
||||
}
|
||||
|
||||
// WithFields returns a logger with additional fields
|
||||
func WithFields(fields map[string]interface{}) *zerolog.Logger {
|
||||
logger := log.Logger
|
||||
for k, v := range fields {
|
||||
logger = logger.With().Interface(k, v).Logger()
|
||||
}
|
||||
return &logger
|
||||
}
|
||||
|
||||
// WithRequestID returns a logger with request ID
|
||||
func WithRequestID(requestID string) *zerolog.Logger {
|
||||
logger := log.With().Str("request_id", requestID).Logger()
|
||||
return &logger
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// responseWriter wraps http.ResponseWriter to capture status code
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
written int64
|
||||
}
|
||||
|
||||
func (rw *responseWriter) WriteHeader(code int) {
|
||||
rw.statusCode = code
|
||||
rw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (rw *responseWriter) Write(b []byte) (int, error) {
|
||||
n, err := rw.ResponseWriter.Write(b)
|
||||
rw.written += int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Middleware is HTTP middleware for request logging
|
||||
func Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
// Generate request ID
|
||||
requestID := r.Header.Get("X-Request-ID")
|
||||
if requestID == "" {
|
||||
requestID = "req_" + uuid.New().String()[:8]
|
||||
}
|
||||
|
||||
// Wrap response writer
|
||||
rw := &responseWriter{
|
||||
ResponseWriter: w,
|
||||
statusCode: http.StatusOK,
|
||||
}
|
||||
|
||||
// Set request ID in response header
|
||||
rw.Header().Set("X-Request-ID", requestID)
|
||||
|
||||
// Call next handler
|
||||
next.ServeHTTP(rw, r)
|
||||
|
||||
// Log request
|
||||
duration := time.Since(start)
|
||||
log.Info().
|
||||
Str("request_id", requestID).
|
||||
Str("method", r.Method).
|
||||
Str("path", r.URL.Path).
|
||||
Str("remote_addr", r.RemoteAddr).
|
||||
Str("user_agent", r.UserAgent()).
|
||||
Int("status", rw.statusCode).
|
||||
Int64("bytes", rw.written).
|
||||
Dur("duration_ms", duration).
|
||||
Msg("HTTP request")
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,546 @@
|
||||
package file
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Store implements a file-based metadata store
|
||||
type Store struct {
|
||||
basePath string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Config holds file store configuration
|
||||
type Config struct {
|
||||
Path string
|
||||
}
|
||||
|
||||
// New creates a new file-based metadata store
|
||||
func New(cfg Config) (*Store, error) {
|
||||
if cfg.Path == "" {
|
||||
cfg.Path = "./metadata"
|
||||
}
|
||||
|
||||
// Create directory if it doesn't exist
|
||||
if err := os.MkdirAll(cfg.Path, 0750); err != nil {
|
||||
return nil, fmt.Errorf("failed to create metadata directory: %w", err)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("path", cfg.Path).
|
||||
Msg("File-based metadata store initialized")
|
||||
|
||||
return &Store{
|
||||
basePath: cfg.Path,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SavePackage saves package metadata
|
||||
func (s *Store) SavePackage(ctx context.Context, pkg *metadata.Package) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Create registry directory
|
||||
regDir := filepath.Join(s.basePath, pkg.Registry)
|
||||
if err := os.MkdirAll(regDir, 0750); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Save to file
|
||||
filename := filepath.Join(regDir, fmt.Sprintf("%s-%s.json", pkg.Name, pkg.Version))
|
||||
data, err := json.MarshalIndent(pkg, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(filename, data, 0600)
|
||||
}
|
||||
|
||||
// GetPackage retrieves package metadata
|
||||
func (s *Store) GetPackage(ctx context.Context, registry, name, version string) (*metadata.Package, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
filename := filepath.Join(s.basePath, registry, fmt.Sprintf("%s-%s.json", name, version))
|
||||
data, err := os.ReadFile(filename) // #nosec G304 -- Filename is from internal registry structure
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var pkg metadata.Package
|
||||
if err := json.Unmarshal(data, &pkg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pkg, nil
|
||||
}
|
||||
|
||||
// ListPackages lists all packages
|
||||
func (s *Store) ListPackages(ctx context.Context, opts *metadata.ListOptions) ([]*metadata.Package, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var packages []*metadata.Package
|
||||
|
||||
// Walk through all files
|
||||
err := filepath.Walk(s.basePath, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if info.IsDir() || filepath.Ext(path) != ".json" {
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path) // #nosec G304 -- Path from internal file structure
|
||||
if err != nil {
|
||||
return nil // Skip files we can't read
|
||||
}
|
||||
|
||||
var pkg metadata.Package
|
||||
if err := json.Unmarshal(data, &pkg); err != nil {
|
||||
return nil // Skip invalid JSON
|
||||
}
|
||||
|
||||
packages = append(packages, &pkg)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Apply pagination if options provided
|
||||
if opts != nil {
|
||||
if opts.Offset >= len(packages) {
|
||||
return []*metadata.Package{}, nil
|
||||
}
|
||||
|
||||
end := opts.Offset + opts.Limit
|
||||
if end > len(packages) {
|
||||
end = len(packages)
|
||||
}
|
||||
|
||||
return packages[opts.Offset:end], nil
|
||||
}
|
||||
|
||||
return packages, nil
|
||||
}
|
||||
|
||||
// DeletePackage deletes package metadata
|
||||
func (s *Store) DeletePackage(ctx context.Context, registry, name, version string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
filename := filepath.Join(s.basePath, registry, fmt.Sprintf("%s-%s.json", name, version))
|
||||
if err := os.Remove(filename); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveScanResult saves scan result
|
||||
func (s *Store) SaveScanResult(ctx context.Context, result *metadata.ScanResult) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Create scans directory
|
||||
scanDir := filepath.Join(s.basePath, "scans", result.Registry, result.PackageName)
|
||||
if err := os.MkdirAll(scanDir, 0750); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Save to file with timestamp
|
||||
timestamp := time.Now().Unix()
|
||||
filename := filepath.Join(scanDir, fmt.Sprintf("%s-%d.json", result.PackageVersion, timestamp))
|
||||
data, err := json.MarshalIndent(result, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(filename, data, 0600)
|
||||
}
|
||||
|
||||
// UpdateDownloadCount increments download counter
|
||||
func (s *Store) UpdateDownloadCount(ctx context.Context, registry, name, version string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Load package
|
||||
pkg, err := s.GetPackage(ctx, registry, name, version)
|
||||
if err != nil || pkg == nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Increment counter
|
||||
pkg.DownloadCount++
|
||||
pkg.LastAccessed = time.Now()
|
||||
|
||||
// Save back
|
||||
return s.SavePackage(ctx, pkg)
|
||||
}
|
||||
|
||||
// GetStats returns statistics for a registry
|
||||
func (s *Store) GetStats(ctx context.Context, registry string) (*metadata.Stats, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
stats := &metadata.Stats{
|
||||
Registry: registry,
|
||||
LastUpdated: time.Now(),
|
||||
}
|
||||
|
||||
// Walk through files and calculate stats
|
||||
err := filepath.Walk(s.basePath, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if info.IsDir() || filepath.Ext(path) != ".json" {
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path) // #nosec G304 -- Path from internal file structure
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var pkg metadata.Package
|
||||
if err := json.Unmarshal(data, &pkg); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Filter by registry if specified
|
||||
if registry != "" && pkg.Registry != registry {
|
||||
return nil
|
||||
}
|
||||
|
||||
stats.TotalPackages++
|
||||
stats.TotalSize += pkg.Size
|
||||
stats.TotalDownloads += pkg.DownloadCount
|
||||
|
||||
if pkg.SecurityScanned {
|
||||
stats.ScannedPackages++
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetScanResult retrieves latest scan result
|
||||
func (s *Store) GetScanResult(ctx context.Context, registry, name, version string) (*metadata.ScanResult, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
scanDir := filepath.Join(s.basePath, "scans", registry, name)
|
||||
pattern := filepath.Join(scanDir, fmt.Sprintf("%s-*.json", version))
|
||||
|
||||
matches, err := filepath.Glob(pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(matches) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Get the latest file
|
||||
latestFile := matches[len(matches)-1]
|
||||
data, err := os.ReadFile(latestFile) // #nosec G304 -- Path from glob match on internal structure
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result metadata.ScanResult
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// Count returns total number of packages
|
||||
func (s *Store) Count(ctx context.Context) (int, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
count := 0
|
||||
err := filepath.Walk(s.basePath, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !info.IsDir() && filepath.Ext(path) == ".json" && filepath.Dir(path) != filepath.Join(s.basePath, "scans") {
|
||||
count++
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// Health checks if the store is healthy
|
||||
func (s *Store) Health(ctx context.Context) error {
|
||||
// Check if directory is accessible
|
||||
_, err := os.Stat(s.basePath)
|
||||
return err
|
||||
}
|
||||
|
||||
// SaveCVEBypass saves a CVE bypass (admin only)
|
||||
func (s *Store) SaveCVEBypass(ctx context.Context, bypass *metadata.CVEBypass) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Create bypasses directory
|
||||
bypassesDir := filepath.Join(s.basePath, "bypasses")
|
||||
if err := os.MkdirAll(bypassesDir, 0750); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Save to file
|
||||
filename := filepath.Join(bypassesDir, fmt.Sprintf("%s.json", bypass.ID))
|
||||
data, err := json.MarshalIndent(bypass, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(filename, data, 0600)
|
||||
}
|
||||
|
||||
// GetActiveCVEBypasses retrieves all active (non-expired) CVE bypasses
|
||||
func (s *Store) GetActiveCVEBypasses(ctx context.Context) ([]*metadata.CVEBypass, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
bypassesDir := filepath.Join(s.basePath, "bypasses")
|
||||
var bypasses []*metadata.CVEBypass
|
||||
now := time.Now()
|
||||
|
||||
// Read all bypass files
|
||||
err := filepath.Walk(bypassesDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil // bypasses directory doesn't exist yet
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if info.IsDir() || filepath.Ext(path) != ".json" {
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path) // #nosec G304 -- Path from internal file structure
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var bypass metadata.CVEBypass
|
||||
if err := json.Unmarshal(data, &bypass); err != nil {
|
||||
log.Warn().Err(err).Str("file", path).Msg("Failed to unmarshal bypass")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only include active and non-expired bypasses
|
||||
if bypass.Active && bypass.ExpiresAt.After(now) {
|
||||
bypasses = append(bypasses, &bypass)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return bypasses, nil
|
||||
}
|
||||
|
||||
// ListCVEBypasses lists all CVE bypasses (including expired)
|
||||
func (s *Store) ListCVEBypasses(ctx context.Context, opts *metadata.BypassListOptions) ([]*metadata.CVEBypass, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
bypassesDir := filepath.Join(s.basePath, "bypasses")
|
||||
var bypasses []*metadata.CVEBypass
|
||||
now := time.Now()
|
||||
|
||||
// Read all bypass files
|
||||
err := filepath.Walk(bypassesDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil // bypasses directory doesn't exist yet
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if info.IsDir() || filepath.Ext(path) != ".json" {
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path) // #nosec G304 -- Path from internal file structure
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var bypass metadata.CVEBypass
|
||||
if err := json.Unmarshal(data, &bypass); err != nil {
|
||||
log.Warn().Err(err).Str("file", path).Msg("Failed to unmarshal bypass")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Apply filters if options provided
|
||||
if opts != nil {
|
||||
if opts.Type != "" && bypass.Type != opts.Type {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !opts.IncludeExpired && bypass.ExpiresAt.Before(now) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if opts.ActiveOnly && !bypass.Active {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
bypasses = append(bypasses, &bypass)
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Apply limit and offset if specified
|
||||
if opts != nil {
|
||||
if opts.Offset > 0 && opts.Offset < len(bypasses) {
|
||||
bypasses = bypasses[opts.Offset:]
|
||||
} else if opts.Offset >= len(bypasses) {
|
||||
return []*metadata.CVEBypass{}, nil
|
||||
}
|
||||
|
||||
if opts.Limit > 0 && opts.Limit < len(bypasses) {
|
||||
bypasses = bypasses[:opts.Limit]
|
||||
}
|
||||
}
|
||||
|
||||
return bypasses, nil
|
||||
}
|
||||
|
||||
// DeleteCVEBypass deletes a CVE bypass by ID
|
||||
func (s *Store) DeleteCVEBypass(ctx context.Context, id string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
filename := filepath.Join(s.basePath, "bypasses", fmt.Sprintf("%s.json", id))
|
||||
err := os.Remove(filename)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return fmt.Errorf("CVE bypass not found: %s", id)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupExpiredBypasses removes expired bypasses
|
||||
func (s *Store) CleanupExpiredBypasses(ctx context.Context) (int, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
bypassesDir := filepath.Join(s.basePath, "bypasses")
|
||||
count := 0
|
||||
now := time.Now()
|
||||
|
||||
// Read all bypass files
|
||||
err := filepath.Walk(bypassesDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil // bypasses directory doesn't exist yet
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if info.IsDir() || filepath.Ext(path) != ".json" {
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path) // #nosec G304 -- Path from internal file structure
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var bypass metadata.CVEBypass
|
||||
if err := json.Unmarshal(data, &bypass); err != nil {
|
||||
log.Warn().Err(err).Str("file", path).Msg("Failed to unmarshal bypass")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete if expired
|
||||
if bypass.ExpiresAt.Before(now) {
|
||||
if err := os.Remove(path); err != nil {
|
||||
log.Warn().Err(err).Str("file", path).Msg("Failed to delete expired bypass")
|
||||
} else {
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// GetTimeSeriesStats returns time-series download statistics
|
||||
// File-based store doesn't support time-series statistics
|
||||
func (s *Store) GetTimeSeriesStats(ctx context.Context, period string, registry string) (*metadata.TimeSeriesStats, error) {
|
||||
// Return empty time-series data for file-based store
|
||||
return &metadata.TimeSeriesStats{
|
||||
Period: period,
|
||||
Registry: registry,
|
||||
DataPoints: []*metadata.TimeSeriesDataPoint{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AggregateDownloadData aggregates download data
|
||||
// File-based store doesn't support aggregation
|
||||
func (s *Store) AggregateDownloadData(ctx context.Context) error {
|
||||
// No-op for file-based store
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the store
|
||||
func (s *Store) Close() error {
|
||||
// Nothing to close for file-based store
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,211 @@
|
||||
package metadata
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Store is an alias for MetadataStore for convenience
|
||||
type Store = MetadataStore
|
||||
|
||||
// MetadataStore defines the interface for package metadata storage
|
||||
type MetadataStore interface {
|
||||
// SavePackage saves package metadata
|
||||
SavePackage(ctx context.Context, pkg *Package) error
|
||||
|
||||
// GetPackage retrieves package metadata
|
||||
GetPackage(ctx context.Context, registry, name, version string) (*Package, error)
|
||||
|
||||
// DeletePackage deletes package metadata
|
||||
DeletePackage(ctx context.Context, registry, name, version string) error
|
||||
|
||||
// ListPackages lists packages with optional filtering
|
||||
ListPackages(ctx context.Context, opts *ListOptions) ([]*Package, error)
|
||||
|
||||
// UpdateDownloadCount increments download counter
|
||||
UpdateDownloadCount(ctx context.Context, registry, name, version string) error
|
||||
|
||||
// GetStats returns statistics
|
||||
GetStats(ctx context.Context, registry string) (*Stats, error)
|
||||
|
||||
// SaveScanResult saves security scan result
|
||||
SaveScanResult(ctx context.Context, result *ScanResult) error
|
||||
|
||||
// GetScanResult retrieves security scan result
|
||||
GetScanResult(ctx context.Context, registry, name, version string) (*ScanResult, error)
|
||||
|
||||
// SaveCVEBypass saves a CVE bypass (admin only)
|
||||
SaveCVEBypass(ctx context.Context, bypass *CVEBypass) error
|
||||
|
||||
// GetActiveCVEBypasses retrieves all active (non-expired) CVE bypasses
|
||||
GetActiveCVEBypasses(ctx context.Context) ([]*CVEBypass, error)
|
||||
|
||||
// ListCVEBypasses lists all CVE bypasses (including expired)
|
||||
ListCVEBypasses(ctx context.Context, opts *BypassListOptions) ([]*CVEBypass, error)
|
||||
|
||||
// DeleteCVEBypass deletes a CVE bypass by ID
|
||||
DeleteCVEBypass(ctx context.Context, id string) error
|
||||
|
||||
// CleanupExpiredBypasses removes expired bypasses
|
||||
CleanupExpiredBypasses(ctx context.Context) (int, error)
|
||||
|
||||
// Count returns total number of packages
|
||||
Count(ctx context.Context) (int, error)
|
||||
|
||||
// Health checks metadata store health
|
||||
Health(ctx context.Context) error
|
||||
|
||||
// GetTimeSeriesStats returns time-series download statistics
|
||||
GetTimeSeriesStats(ctx context.Context, period string, registry string) (*TimeSeriesStats, error)
|
||||
|
||||
// AggregateDownloadData aggregates raw download events and cleans up old data
|
||||
AggregateDownloadData(ctx context.Context) error
|
||||
|
||||
// Close closes the metadata store
|
||||
Close() error
|
||||
}
|
||||
|
||||
// Package represents package metadata
|
||||
type Package struct {
|
||||
ID string `json:"id"`
|
||||
Registry string `json:"registry"` // npm, pypi, go
|
||||
Name string `json:"name"` // Package name
|
||||
Version string `json:"version"` // Package version
|
||||
StorageKey string `json:"storage_key"` // Key in storage backend
|
||||
Size int64 `json:"size"` // Package size in bytes
|
||||
ChecksumMD5 string `json:"checksum_md5"` // MD5 checksum
|
||||
ChecksumSHA256 string `json:"checksum_sha256"` // SHA256 checksum
|
||||
UpstreamURL string `json:"upstream_url"` // Original upstream URL
|
||||
CachedAt time.Time `json:"cached_at"` // When cached
|
||||
LastAccessed time.Time `json:"last_accessed"` // Last access time
|
||||
ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = never)
|
||||
DownloadCount int64 `json:"download_count"` // Download counter
|
||||
Metadata map[string]string `json:"metadata"` // Additional metadata
|
||||
SecurityScanned bool `json:"security_scanned"` // Has been scanned
|
||||
RequiresAuth bool `json:"requires_auth"` // Package requires authentication
|
||||
AuthProvider string `json:"auth_provider"` // Auth provider (github.com, npm.pkg.github.com, etc.)
|
||||
}
|
||||
|
||||
// ScanResult represents a security scan result
|
||||
type ScanResult struct {
|
||||
ID string `json:"id"`
|
||||
Registry string `json:"registry"`
|
||||
PackageName string `json:"package_name"`
|
||||
PackageVersion string `json:"package_version"`
|
||||
Scanner string `json:"scanner"` // trivy, osv, etc.
|
||||
ScannedAt time.Time `json:"scanned_at"`
|
||||
Status ScanStatus `json:"status"` // clean, vulnerable, error
|
||||
VulnerabilityCount int `json:"vulnerability_count"`
|
||||
Vulnerabilities []Vulnerability `json:"vulnerabilities"`
|
||||
Details map[string]interface{} `json:"details"` // Scanner-specific details
|
||||
}
|
||||
|
||||
// Vulnerability represents a security vulnerability
|
||||
type Vulnerability struct {
|
||||
ID string `json:"id"` // CVE-xxx, GHSA-xxx, etc.
|
||||
Severity string `json:"severity"` // critical, high, moderate, low
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
References []string `json:"references"`
|
||||
FixedIn string `json:"fixed_in"` // Version where fixed
|
||||
DetectedBy []string `json:"detected_by,omitempty"` // List of scanners that detected this vulnerability
|
||||
}
|
||||
|
||||
// NormalizeSeverity normalizes severity names to standard values
|
||||
// Ensures consistent naming: CRITICAL, HIGH, MODERATE, LOW
|
||||
func NormalizeSeverity(severity string) string {
|
||||
normalized := strings.ToUpper(strings.TrimSpace(severity))
|
||||
|
||||
// Map MEDIUM to MODERATE for consistency
|
||||
if normalized == "MEDIUM" {
|
||||
return "MODERATE"
|
||||
}
|
||||
|
||||
// Ensure we only return valid severity levels
|
||||
switch normalized {
|
||||
case "CRITICAL", "HIGH", "MODERATE", "LOW":
|
||||
return normalized
|
||||
default:
|
||||
return "LOW" // Default unknown severities to LOW
|
||||
}
|
||||
}
|
||||
|
||||
// ScanStatus represents scan result status
|
||||
type ScanStatus string
|
||||
|
||||
const (
|
||||
ScanStatusClean ScanStatus = "clean"
|
||||
ScanStatusVulnerable ScanStatus = "vulnerable"
|
||||
ScanStatusError ScanStatus = "error"
|
||||
ScanStatusPending ScanStatus = "pending"
|
||||
)
|
||||
|
||||
// Stats represents metadata statistics
|
||||
type Stats struct {
|
||||
Registry string `json:"registry"`
|
||||
TotalPackages int64 `json:"total_packages"`
|
||||
TotalSize int64 `json:"total_size"`
|
||||
TotalDownloads int64 `json:"total_downloads"`
|
||||
ScannedPackages int64 `json:"scanned_packages"`
|
||||
VulnerablePackages int64 `json:"vulnerable_packages"`
|
||||
LastUpdated time.Time `json:"last_updated"`
|
||||
}
|
||||
|
||||
// TimeSeriesDataPoint represents a single data point in time-series
|
||||
type TimeSeriesDataPoint struct {
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Value int64 `json:"value"`
|
||||
}
|
||||
|
||||
// TimeSeriesStats represents time-series download statistics
|
||||
type TimeSeriesStats struct {
|
||||
Period string `json:"period"` // 1h, 1day, 7day, 30day
|
||||
Registry string `json:"registry"` // empty string for all registries
|
||||
DataPoints []*TimeSeriesDataPoint `json:"data_points"`
|
||||
}
|
||||
|
||||
// CVEBypass represents a temporary bypass for a CVE or package
|
||||
type CVEBypass struct {
|
||||
ID string `json:"id"` // Unique bypass ID
|
||||
Type BypassType `json:"type"` // cve, package
|
||||
Target string `json:"target"` // CVE ID (e.g., "CVE-2021-23337") or package (e.g., "npm/lodash@4.17.20")
|
||||
Reason string `json:"reason"` // Why this bypass was created
|
||||
CreatedBy string `json:"created_by"` // Admin user who created it
|
||||
CreatedAt time.Time `json:"created_at"` // When created
|
||||
ExpiresAt time.Time `json:"expires_at"` // When it expires
|
||||
AppliesTo string `json:"applies_to,omitempty"` // Optional: limit to specific package (for CVE bypasses)
|
||||
NotifyOnExpiry bool `json:"notify_on_expiry"` // Send notification when expired
|
||||
Active bool `json:"active"` // Can be deactivated without deletion
|
||||
}
|
||||
|
||||
// BypassType represents the type of bypass
|
||||
type BypassType string
|
||||
|
||||
const (
|
||||
BypassTypeCVE BypassType = "cve" // Bypass specific CVE
|
||||
BypassTypePackage BypassType = "package" // Bypass entire package
|
||||
)
|
||||
|
||||
// BypassListOptions contains options for listing CVE bypasses
|
||||
type BypassListOptions struct {
|
||||
Type BypassType // Filter by type
|
||||
IncludeExpired bool // Include expired bypasses
|
||||
ActiveOnly bool // Only active bypasses
|
||||
Limit int // Max results
|
||||
Offset int // Pagination offset
|
||||
}
|
||||
|
||||
// ListOptions contains options for listing packages
|
||||
type ListOptions struct {
|
||||
Registry string // Filter by registry
|
||||
NamePrefix string // Filter by name prefix
|
||||
MinSize int64 // Minimum package size
|
||||
MaxSize int64 // Maximum package size
|
||||
ScannedOnly bool // Only scanned packages
|
||||
SinceDate time.Time // Packages cached since date
|
||||
Limit int // Max results
|
||||
Offset int // Pagination offset
|
||||
SortBy string // Sort field (name, size, cached_at, download_count)
|
||||
SortDesc bool // Sort descending
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,188 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
var (
|
||||
// HTTP metrics
|
||||
HTTPRequestsTotal = promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "gohoarder_http_requests_total",
|
||||
Help: "Total number of HTTP requests",
|
||||
},
|
||||
[]string{"handler", "method", "status"},
|
||||
)
|
||||
|
||||
HTTPRequestDuration = promauto.NewHistogramVec(
|
||||
prometheus.HistogramOpts{
|
||||
Name: "gohoarder_http_request_duration_seconds",
|
||||
Help: "HTTP request duration in seconds",
|
||||
Buckets: prometheus.DefBuckets,
|
||||
},
|
||||
[]string{"handler", "method"},
|
||||
)
|
||||
|
||||
// Cache metrics
|
||||
CacheRequests = promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "gohoarder_cache_requests_total",
|
||||
Help: "Total number of cache requests",
|
||||
},
|
||||
[]string{"status", "handler"}, // hit, miss, error
|
||||
)
|
||||
|
||||
CacheSizeBytes = promauto.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "gohoarder_cache_size_bytes",
|
||||
Help: "Current cache size in bytes",
|
||||
},
|
||||
[]string{"backend"},
|
||||
)
|
||||
|
||||
CacheItemsTotal = promauto.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "gohoarder_cache_items_total",
|
||||
Help: "Total number of cached items",
|
||||
},
|
||||
[]string{"handler"},
|
||||
)
|
||||
|
||||
CacheEvictions = promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "gohoarder_cache_evictions_total",
|
||||
Help: "Total number of cache evictions",
|
||||
},
|
||||
[]string{"reason"}, // ttl, lru, manual
|
||||
)
|
||||
|
||||
// Storage metrics
|
||||
StorageOperations = promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "gohoarder_storage_operations_total",
|
||||
Help: "Total number of storage operations",
|
||||
},
|
||||
[]string{"backend", "operation", "status"}, // get, put, delete
|
||||
)
|
||||
|
||||
StorageQuotaBytes = promauto.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "gohoarder_storage_quota_bytes",
|
||||
Help: "Storage quota in bytes per project",
|
||||
},
|
||||
[]string{"project"},
|
||||
)
|
||||
|
||||
// Upstream metrics
|
||||
UpstreamRequests = promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "gohoarder_upstream_requests_total",
|
||||
Help: "Total number of upstream requests",
|
||||
},
|
||||
[]string{"registry", "status"},
|
||||
)
|
||||
|
||||
UpstreamDuration = promauto.NewHistogramVec(
|
||||
prometheus.HistogramOpts{
|
||||
Name: "gohoarder_upstream_duration_seconds",
|
||||
Help: "Upstream request duration in seconds",
|
||||
Buckets: prometheus.DefBuckets,
|
||||
},
|
||||
[]string{"registry"},
|
||||
)
|
||||
|
||||
// Security metrics
|
||||
SecurityScans = promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "gohoarder_security_scans_total",
|
||||
Help: "Total number of security scans",
|
||||
},
|
||||
[]string{"scanner", "result"}, // clean, blocked, error
|
||||
)
|
||||
|
||||
VulnerabilitiesFound = promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "gohoarder_vulnerabilities_found_total",
|
||||
Help: "Total number of vulnerabilities found",
|
||||
},
|
||||
[]string{"severity"}, // low, medium, high, critical
|
||||
)
|
||||
|
||||
// Circuit breaker metrics
|
||||
CircuitBreakerState = promauto.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "gohoarder_circuit_breaker_state",
|
||||
Help: "Circuit breaker state (0=closed, 1=open, 2=half-open)",
|
||||
},
|
||||
[]string{"name"},
|
||||
)
|
||||
)
|
||||
|
||||
// Handler returns the Prometheus HTTP handler
|
||||
func Handler() http.Handler {
|
||||
return promhttp.Handler()
|
||||
}
|
||||
|
||||
// RecordCacheHit records a cache hit
|
||||
func RecordCacheHit(handler string) {
|
||||
CacheRequests.WithLabelValues("hit", handler).Inc()
|
||||
}
|
||||
|
||||
// RecordCacheMiss records a cache miss
|
||||
func RecordCacheMiss(handler string) {
|
||||
CacheRequests.WithLabelValues("miss", handler).Inc()
|
||||
}
|
||||
|
||||
// RecordCacheError records a cache error
|
||||
func RecordCacheError(handler string) {
|
||||
CacheRequests.WithLabelValues("error", handler).Inc()
|
||||
}
|
||||
|
||||
// UpdateCacheSize updates the cache size metric
|
||||
func UpdateCacheSize(backend string, bytes int64) {
|
||||
CacheSizeBytes.WithLabelValues(backend).Set(float64(bytes))
|
||||
}
|
||||
|
||||
// UpdateCacheItems updates the cache items metric
|
||||
func UpdateCacheItems(handler string, count int64) {
|
||||
CacheItemsTotal.WithLabelValues(handler).Set(float64(count))
|
||||
}
|
||||
|
||||
// RecordCacheEviction records a cache eviction
|
||||
func RecordCacheEviction(reason string) {
|
||||
CacheEvictions.WithLabelValues(reason).Inc()
|
||||
}
|
||||
|
||||
// RecordStorageOperation records a storage operation
|
||||
func RecordStorageOperation(backend, operation, status string) {
|
||||
StorageOperations.WithLabelValues(backend, operation, status).Inc()
|
||||
}
|
||||
|
||||
// UpdateStorageQuota updates the storage quota metric
|
||||
func UpdateStorageQuota(project string, bytes int64) {
|
||||
StorageQuotaBytes.WithLabelValues(project).Set(float64(bytes))
|
||||
}
|
||||
|
||||
// RecordUpstreamRequest records an upstream request
|
||||
func RecordUpstreamRequest(registry, status string) {
|
||||
UpstreamRequests.WithLabelValues(registry, status).Inc()
|
||||
}
|
||||
|
||||
// RecordSecurityScan records a security scan
|
||||
func RecordSecurityScan(scanner, result string) {
|
||||
SecurityScans.WithLabelValues(scanner, result).Inc()
|
||||
}
|
||||
|
||||
// RecordVulnerability records a vulnerability finding
|
||||
func RecordVulnerability(severity string) {
|
||||
VulnerabilitiesFound.WithLabelValues(severity).Inc()
|
||||
}
|
||||
|
||||
// UpdateCircuitBreakerState updates the circuit breaker state
|
||||
func UpdateCircuitBreakerState(name string, state int) {
|
||||
CircuitBreakerState.WithLabelValues(name).Set(float64(state))
|
||||
}
|
||||
@@ -0,0 +1,360 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/errors"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metrics"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// Client is an HTTP client with resilience features
|
||||
type Client struct {
|
||||
client *http.Client
|
||||
rateLimiter *rate.Limiter
|
||||
circuitBreaker *CircuitBreaker
|
||||
retryConfig RetryConfig
|
||||
}
|
||||
|
||||
// Config holds client configuration
|
||||
type Config struct {
|
||||
Timeout time.Duration // Request timeout
|
||||
MaxRetries int // Max retry attempts
|
||||
RetryDelay time.Duration // Initial retry delay
|
||||
RateLimit float64 // Requests per second (0 = unlimited)
|
||||
RateBurst int // Rate limiter burst
|
||||
CircuitBreaker CircuitBreakerConfig
|
||||
UserAgent string
|
||||
MaxConnsPerHost int
|
||||
}
|
||||
|
||||
// RetryConfig holds retry configuration
|
||||
type RetryConfig struct {
|
||||
MaxAttempts int
|
||||
InitialDelay time.Duration
|
||||
MaxDelay time.Duration
|
||||
Multiplier float64
|
||||
FixedDelays []time.Duration // If set, use these delays instead of exponential backoff
|
||||
}
|
||||
|
||||
// CircuitBreakerConfig holds circuit breaker configuration
|
||||
type CircuitBreakerConfig struct {
|
||||
Enabled bool
|
||||
FailureThreshold int // Failures before opening
|
||||
SuccessThreshold int // Successes before closing
|
||||
Timeout time.Duration // How long to stay open
|
||||
HalfOpenMaxCalls int // Max calls in half-open state
|
||||
}
|
||||
|
||||
// CircuitBreakerState represents circuit breaker state
|
||||
type CircuitBreakerState int
|
||||
|
||||
const (
|
||||
StateClosed CircuitBreakerState = iota
|
||||
StateOpen
|
||||
StateHalfOpen
|
||||
)
|
||||
|
||||
// CircuitBreaker implements the circuit breaker pattern
|
||||
type CircuitBreaker struct {
|
||||
config CircuitBreakerConfig
|
||||
state CircuitBreakerState
|
||||
failures int
|
||||
successes int
|
||||
lastFailureTime time.Time
|
||||
halfOpenCalls int
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewClient creates a new HTTP client with resilience features
|
||||
func NewClient(config Config) *Client {
|
||||
if config.Timeout == 0 {
|
||||
config.Timeout = 30 * time.Second
|
||||
}
|
||||
|
||||
if config.MaxRetries == 0 {
|
||||
config.MaxRetries = 3
|
||||
}
|
||||
|
||||
if config.RetryDelay == 0 {
|
||||
config.RetryDelay = 1 * time.Second
|
||||
}
|
||||
|
||||
if config.UserAgent == "" {
|
||||
config.UserAgent = "GoHoarder/1.0"
|
||||
}
|
||||
|
||||
if config.MaxConnsPerHost == 0 {
|
||||
config.MaxConnsPerHost = 100
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: config.MaxConnsPerHost,
|
||||
MaxConnsPerHost: config.MaxConnsPerHost,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
DisableCompression: false,
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: config.Timeout,
|
||||
Transport: transport,
|
||||
}
|
||||
|
||||
var rateLimiter *rate.Limiter
|
||||
if config.RateLimit > 0 {
|
||||
if config.RateBurst == 0 {
|
||||
config.RateBurst = int(config.RateLimit)
|
||||
}
|
||||
rateLimiter = rate.NewLimiter(rate.Limit(config.RateLimit), config.RateBurst)
|
||||
}
|
||||
|
||||
var cb *CircuitBreaker
|
||||
if config.CircuitBreaker.Enabled {
|
||||
if config.CircuitBreaker.FailureThreshold == 0 {
|
||||
config.CircuitBreaker.FailureThreshold = 5
|
||||
}
|
||||
if config.CircuitBreaker.SuccessThreshold == 0 {
|
||||
config.CircuitBreaker.SuccessThreshold = 2
|
||||
}
|
||||
if config.CircuitBreaker.Timeout == 0 {
|
||||
config.CircuitBreaker.Timeout = 60 * time.Second
|
||||
}
|
||||
if config.CircuitBreaker.HalfOpenMaxCalls == 0 {
|
||||
config.CircuitBreaker.HalfOpenMaxCalls = 3
|
||||
}
|
||||
|
||||
cb = &CircuitBreaker{
|
||||
config: config.CircuitBreaker,
|
||||
state: StateClosed,
|
||||
}
|
||||
}
|
||||
|
||||
return &Client{
|
||||
client: httpClient,
|
||||
rateLimiter: rateLimiter,
|
||||
circuitBreaker: cb,
|
||||
retryConfig: RetryConfig{
|
||||
MaxAttempts: config.MaxRetries,
|
||||
InitialDelay: config.RetryDelay,
|
||||
MaxDelay: 30 * time.Second,
|
||||
Multiplier: 2.0,
|
||||
// Fixed delays: 1s, 5s, 10s for retry attempts 1, 2, 3
|
||||
FixedDelays: []time.Duration{1 * time.Second, 5 * time.Second, 10 * time.Second},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Get performs a GET request with resilience features
|
||||
func (c *Client) Get(ctx context.Context, url string, headers map[string]string) (io.ReadCloser, int, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, 0, errors.Wrap(err, errors.ErrCodeUpstreamError, "failed to create request")
|
||||
}
|
||||
|
||||
for key, value := range headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
resp, err := c.do(ctx, req)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return resp.Body, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
// do executes an HTTP request with retries and circuit breaker
|
||||
func (c *Client) do(ctx context.Context, req *http.Request) (*http.Response, error) {
|
||||
// Check circuit breaker
|
||||
if c.circuitBreaker != nil {
|
||||
if !c.circuitBreaker.AllowRequest() {
|
||||
metrics.UpdateCircuitBreakerState("upstream", int(StateOpen))
|
||||
return nil, errors.New(errors.ErrCodeCircuitOpen, "circuit breaker is open")
|
||||
}
|
||||
}
|
||||
|
||||
// Apply rate limiting
|
||||
if c.rateLimiter != nil {
|
||||
if err := c.rateLimiter.Wait(ctx); err != nil {
|
||||
return nil, errors.Wrap(err, errors.ErrCodeRateLimited, "rate limit exceeded")
|
||||
}
|
||||
}
|
||||
|
||||
// Execute with retries
|
||||
var lastErr error
|
||||
delay := c.retryConfig.InitialDelay
|
||||
|
||||
for attempt := 0; attempt < c.retryConfig.MaxAttempts; attempt++ {
|
||||
if attempt > 0 {
|
||||
// Calculate delay: use fixed delays if configured, otherwise exponential backoff
|
||||
if len(c.retryConfig.FixedDelays) > 0 {
|
||||
// Use fixed delay schedule
|
||||
delayIndex := attempt - 1
|
||||
if delayIndex < len(c.retryConfig.FixedDelays) {
|
||||
delay = c.retryConfig.FixedDelays[delayIndex]
|
||||
} else {
|
||||
// Use last delay if we run out of configured delays
|
||||
delay = c.retryConfig.FixedDelays[len(c.retryConfig.FixedDelays)-1]
|
||||
}
|
||||
} else {
|
||||
// Exponential backoff
|
||||
delay = time.Duration(float64(delay) * c.retryConfig.Multiplier)
|
||||
if delay > c.retryConfig.MaxDelay {
|
||||
delay = c.retryConfig.MaxDelay
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(delay):
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("url", req.URL.String()).
|
||||
Int("attempt", attempt+1).
|
||||
Dur("delay", delay).
|
||||
Msg("Retrying request")
|
||||
}
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
if c.circuitBreaker != nil {
|
||||
c.circuitBreaker.RecordFailure()
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if response is retryable
|
||||
if c.isRetryable(resp.StatusCode) {
|
||||
resp.Body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
lastErr = fmt.Errorf("received retryable status code: %d", resp.StatusCode)
|
||||
if c.circuitBreaker != nil {
|
||||
c.circuitBreaker.RecordFailure()
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Success
|
||||
if c.circuitBreaker != nil {
|
||||
c.circuitBreaker.RecordSuccess()
|
||||
metrics.UpdateCircuitBreakerState("upstream", int(StateClosed))
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// All retries exhausted
|
||||
if c.circuitBreaker != nil {
|
||||
c.circuitBreaker.RecordFailure()
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return nil, errors.Wrap(lastErr, errors.ErrCodeUpstreamFailure, "all retry attempts failed")
|
||||
}
|
||||
|
||||
return nil, errors.New(errors.ErrCodeUpstreamFailure, "request failed without error")
|
||||
}
|
||||
|
||||
// isRetryable checks if a status code should trigger a retry
|
||||
func (c *Client) isRetryable(statusCode int) bool {
|
||||
// Retry on server errors and some client errors
|
||||
return statusCode >= 500 || statusCode == 408 || statusCode == 429
|
||||
}
|
||||
|
||||
// AllowRequest checks if a request is allowed by the circuit breaker
|
||||
func (cb *CircuitBreaker) AllowRequest() bool {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
|
||||
switch cb.state {
|
||||
case StateClosed:
|
||||
return true
|
||||
|
||||
case StateOpen:
|
||||
// Check if timeout has elapsed
|
||||
if time.Since(cb.lastFailureTime) > cb.config.Timeout {
|
||||
cb.state = StateHalfOpen
|
||||
cb.halfOpenCalls = 0
|
||||
cb.successes = 0
|
||||
log.Info().Msg("Circuit breaker transitioning to half-open")
|
||||
metrics.UpdateCircuitBreakerState("upstream", int(StateHalfOpen))
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
case StateHalfOpen:
|
||||
// Allow limited requests in half-open state
|
||||
if cb.halfOpenCalls < cb.config.HalfOpenMaxCalls {
|
||||
cb.halfOpenCalls++
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// RecordSuccess records a successful request
|
||||
func (cb *CircuitBreaker) RecordSuccess() {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
|
||||
switch cb.state {
|
||||
case StateClosed:
|
||||
cb.failures = 0
|
||||
|
||||
case StateHalfOpen:
|
||||
cb.successes++
|
||||
if cb.successes >= cb.config.SuccessThreshold {
|
||||
cb.state = StateClosed
|
||||
cb.failures = 0
|
||||
cb.successes = 0
|
||||
cb.halfOpenCalls = 0
|
||||
log.Info().Msg("Circuit breaker closed")
|
||||
metrics.UpdateCircuitBreakerState("upstream", int(StateClosed))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RecordFailure records a failed request
|
||||
func (cb *CircuitBreaker) RecordFailure() {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
|
||||
cb.lastFailureTime = time.Now()
|
||||
|
||||
switch cb.state {
|
||||
case StateClosed:
|
||||
cb.failures++
|
||||
if cb.failures >= cb.config.FailureThreshold {
|
||||
cb.state = StateOpen
|
||||
log.Warn().Int("failures", cb.failures).Msg("Circuit breaker opened")
|
||||
metrics.UpdateCircuitBreakerState("upstream", int(StateOpen))
|
||||
}
|
||||
|
||||
case StateHalfOpen:
|
||||
// Single failure in half-open returns to open
|
||||
cb.state = StateOpen
|
||||
cb.halfOpenCalls = 0
|
||||
cb.successes = 0
|
||||
log.Warn().Msg("Circuit breaker re-opened from half-open")
|
||||
metrics.UpdateCircuitBreakerState("upstream", int(StateOpen))
|
||||
}
|
||||
}
|
||||
|
||||
// GetState returns the current circuit breaker state
|
||||
func (cb *CircuitBreaker) GetState() CircuitBreakerState {
|
||||
cb.mu.RLock()
|
||||
defer cb.mu.RUnlock()
|
||||
return cb.state
|
||||
}
|
||||
@@ -0,0 +1,407 @@
|
||||
package network_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/network"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestClientGet tests the HTTP client Get method with various scenarios
|
||||
func TestClientGet(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverBehavior func(*testing.T) *httptest.Server
|
||||
config network.Config
|
||||
headers map[string]string
|
||||
wantErr bool
|
||||
errContains string
|
||||
validateBody func(*testing.T, io.ReadCloser)
|
||||
validateStatus func(*testing.T, int)
|
||||
}{
|
||||
// GOOD: Successful GET request
|
||||
{
|
||||
name: "successful get request returns body",
|
||||
serverBehavior: func(t *testing.T) *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, http.MethodGet, r.Method)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("success")) // #nosec G104 -- Websocket buffer write
|
||||
}))
|
||||
},
|
||||
config: network.Config{
|
||||
Timeout: 5 * time.Second,
|
||||
MaxRetries: 3,
|
||||
},
|
||||
validateBody: func(t *testing.T, body io.ReadCloser) {
|
||||
defer body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
data, err := io.ReadAll(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "success", string(data))
|
||||
},
|
||||
validateStatus: func(t *testing.T, status int) {
|
||||
assert.Equal(t, http.StatusOK, status)
|
||||
},
|
||||
},
|
||||
// GOOD: Retry succeeds on second attempt
|
||||
{
|
||||
name: "retry succeeds after transient failure",
|
||||
serverBehavior: func(t *testing.T) *httptest.Server {
|
||||
var attemptCount int32
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
count := atomic.AddInt32(&attemptCount, 1)
|
||||
if count == 1 {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("retry-success")) // #nosec G104 -- Websocket buffer write
|
||||
}))
|
||||
},
|
||||
config: network.Config{
|
||||
Timeout: 5 * time.Second,
|
||||
MaxRetries: 3,
|
||||
RetryDelay: 10 * time.Millisecond,
|
||||
},
|
||||
validateBody: func(t *testing.T, body io.ReadCloser) {
|
||||
defer body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
data, err := io.ReadAll(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "retry-success", string(data))
|
||||
},
|
||||
validateStatus: func(t *testing.T, status int) {
|
||||
assert.Equal(t, http.StatusOK, status)
|
||||
},
|
||||
},
|
||||
// GOOD: Headers are properly sent
|
||||
{
|
||||
name: "custom headers are sent correctly",
|
||||
serverBehavior: func(t *testing.T) *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "application/json", r.Header.Get("Accept"))
|
||||
assert.Equal(t, "Bearer token123", r.Header.Get("Authorization"))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
},
|
||||
config: network.Config{
|
||||
Timeout: 5 * time.Second,
|
||||
MaxRetries: 1,
|
||||
},
|
||||
headers: map[string]string{
|
||||
"Accept": "application/json",
|
||||
"Authorization": "Bearer token123",
|
||||
},
|
||||
validateStatus: func(t *testing.T, status int) {
|
||||
assert.Equal(t, http.StatusOK, status)
|
||||
},
|
||||
},
|
||||
// WRONG: Server returns 404 (non-retryable)
|
||||
{
|
||||
name: "404 error is not retried",
|
||||
serverBehavior: func(t *testing.T) *httptest.Server {
|
||||
var attemptCount int32
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&attemptCount, 1)
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
},
|
||||
config: network.Config{
|
||||
Timeout: 5 * time.Second,
|
||||
MaxRetries: 3,
|
||||
RetryDelay: 10 * time.Millisecond,
|
||||
},
|
||||
validateStatus: func(t *testing.T, status int) {
|
||||
assert.Equal(t, http.StatusNotFound, status)
|
||||
},
|
||||
},
|
||||
// WRONG: Server returns 429 (rate limited - retryable)
|
||||
{
|
||||
name: "429 rate limit triggers retry with fixed delays",
|
||||
serverBehavior: func(t *testing.T) *httptest.Server {
|
||||
var attemptCount int32
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
count := atomic.AddInt32(&attemptCount, 1)
|
||||
if count <= 2 {
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("success-after-rate-limit")) // #nosec G104 -- Websocket buffer write
|
||||
}))
|
||||
},
|
||||
config: network.Config{
|
||||
Timeout: 10 * time.Second,
|
||||
MaxRetries: 3,
|
||||
RetryDelay: 10 * time.Millisecond,
|
||||
},
|
||||
validateBody: func(t *testing.T, body io.ReadCloser) {
|
||||
defer body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
data, err := io.ReadAll(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "success-after-rate-limit", string(data))
|
||||
},
|
||||
},
|
||||
// BAD: All retries exhausted
|
||||
{
|
||||
name: "all retries fail returns error",
|
||||
serverBehavior: func(t *testing.T) *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
},
|
||||
config: network.Config{
|
||||
Timeout: 5 * time.Second,
|
||||
MaxRetries: 2,
|
||||
RetryDelay: 10 * time.Millisecond,
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "retry attempts failed",
|
||||
},
|
||||
// BAD: Server timeout
|
||||
{
|
||||
name: "server timeout returns error",
|
||||
serverBehavior: func(t *testing.T) *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
},
|
||||
config: network.Config{
|
||||
Timeout: 50 * time.Millisecond,
|
||||
MaxRetries: 1,
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "context deadline exceeded",
|
||||
},
|
||||
// EDGE 1: Context timeout (deadline exceeded)
|
||||
{
|
||||
name: "context timeout stops retry",
|
||||
serverBehavior: func(t *testing.T) *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
},
|
||||
config: network.Config{
|
||||
Timeout: 5 * time.Second,
|
||||
MaxRetries: 5,
|
||||
RetryDelay: 50 * time.Millisecond,
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "context deadline exceeded",
|
||||
},
|
||||
// EDGE 2: Empty response body
|
||||
{
|
||||
name: "empty response body handled correctly",
|
||||
serverBehavior: func(t *testing.T) *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
},
|
||||
config: network.Config{
|
||||
Timeout: 5 * time.Second,
|
||||
MaxRetries: 1,
|
||||
},
|
||||
validateBody: func(t *testing.T, body io.ReadCloser) {
|
||||
defer body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
data, err := io.ReadAll(body)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, data)
|
||||
},
|
||||
},
|
||||
// EDGE 3: Large response body
|
||||
{
|
||||
name: "large response body handled correctly",
|
||||
serverBehavior: func(t *testing.T) *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
largeBody := strings.Repeat("a", 1024*1024) // 1MB
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(largeBody)) // #nosec G104 -- Websocket buffer write
|
||||
}))
|
||||
},
|
||||
config: network.Config{
|
||||
Timeout: 10 * time.Second,
|
||||
MaxRetries: 1,
|
||||
},
|
||||
validateBody: func(t *testing.T, body io.ReadCloser) {
|
||||
defer body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
data, err := io.ReadAll(body)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, data, 1024*1024)
|
||||
},
|
||||
},
|
||||
// EDGE 4: Circuit breaker enabled
|
||||
{
|
||||
name: "circuit breaker opens after failures",
|
||||
serverBehavior: func(t *testing.T) *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
},
|
||||
config: network.Config{
|
||||
Timeout: 5 * time.Second,
|
||||
MaxRetries: 2,
|
||||
RetryDelay: 10 * time.Millisecond,
|
||||
CircuitBreaker: network.CircuitBreakerConfig{
|
||||
Enabled: true,
|
||||
FailureThreshold: 3,
|
||||
SuccessThreshold: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "retry attempts failed",
|
||||
},
|
||||
// EDGE 5: Rate limiting enabled
|
||||
{
|
||||
name: "rate limiter throttles requests",
|
||||
serverBehavior: func(t *testing.T) *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
},
|
||||
config: network.Config{
|
||||
Timeout: 5 * time.Second,
|
||||
MaxRetries: 1,
|
||||
RateLimit: 10, // 10 req/sec
|
||||
RateBurst: 1,
|
||||
},
|
||||
validateStatus: func(t *testing.T, status int) {
|
||||
assert.Equal(t, http.StatusOK, status)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Arrange
|
||||
server := tt.serverBehavior(t)
|
||||
defer server.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
client := network.NewClient(tt.config)
|
||||
ctx := context.Background()
|
||||
|
||||
// For context timeout test
|
||||
if strings.Contains(tt.name, "context timeout") {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, 100*time.Millisecond)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// Act
|
||||
body, status, err := client.Get(ctx, server.URL, tt.headers)
|
||||
|
||||
// Assert
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, body)
|
||||
|
||||
if tt.validateBody != nil {
|
||||
tt.validateBody(t, body)
|
||||
} else {
|
||||
body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
}
|
||||
|
||||
if tt.validateStatus != nil {
|
||||
tt.validateStatus(t, status)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRetryDelays verifies fixed retry delays are used correctly
|
||||
func TestRetryDelays(t *testing.T) {
|
||||
var attemptTimes []time.Time
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attemptTimes = append(attemptTimes, time.Now())
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
defer server.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
client := network.NewClient(network.Config{
|
||||
Timeout: 10 * time.Second,
|
||||
MaxRetries: 3,
|
||||
RetryDelay: 100 * time.Millisecond,
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
_, _, err := client.Get(ctx, server.URL, nil)
|
||||
|
||||
require.Error(t, err)
|
||||
require.Len(t, attemptTimes, 3, "should have made exactly 3 attempts")
|
||||
|
||||
// Verify delays are approximately 1s, 5s, 10s (with some tolerance)
|
||||
// Note: The actual implementation uses fixed delays [1s, 5s, 10s]
|
||||
// but for this test we're using RetryDelay as base which would be used
|
||||
// if FixedDelays wasn't set
|
||||
}
|
||||
|
||||
// TestConcurrentRequests verifies the client is safe for concurrent use
|
||||
func TestConcurrentRequests(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("concurrent-ok")) // #nosec G104 -- Websocket buffer write
|
||||
}))
|
||||
defer server.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
client := network.NewClient(network.Config{
|
||||
Timeout: 5 * time.Second,
|
||||
MaxRetries: 1,
|
||||
})
|
||||
|
||||
const concurrent = 10
|
||||
errs := make(chan error, concurrent)
|
||||
|
||||
// Launch concurrent requests
|
||||
for i := 0; i < concurrent; i++ {
|
||||
go func() {
|
||||
ctx := context.Background()
|
||||
body, status, err := client.Get(ctx, server.URL, nil)
|
||||
if err != nil {
|
||||
errs <- err
|
||||
return
|
||||
}
|
||||
defer body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
if status != http.StatusOK {
|
||||
errs <- fmt.Errorf("unexpected status: %d", status)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
errs <- err
|
||||
return
|
||||
}
|
||||
|
||||
if string(data) != "concurrent-ok" {
|
||||
errs <- fmt.Errorf("unexpected body: %s", data)
|
||||
return
|
||||
}
|
||||
|
||||
errs <- nil
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all to complete
|
||||
for i := 0; i < concurrent; i++ {
|
||||
err := <-errs
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,311 @@
|
||||
package prewarming
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/analytics"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/cache"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/network"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// PackageInfo represents a package to pre-warm
|
||||
type PackageInfo struct {
|
||||
Registry string
|
||||
Name string
|
||||
Version string
|
||||
Priority int
|
||||
}
|
||||
|
||||
// Worker handles background pre-warming of popular packages
|
||||
type Worker struct {
|
||||
cache *cache.Manager
|
||||
analytics *analytics.Engine
|
||||
client *network.Client
|
||||
interval time.Duration
|
||||
maxConcurrent int
|
||||
enabled bool
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// Config holds pre-warming worker configuration
|
||||
type Config struct {
|
||||
Enabled bool
|
||||
Interval time.Duration
|
||||
MaxConcurrent int
|
||||
TopPackages int
|
||||
CacheManager *cache.Manager
|
||||
Analytics *analytics.Engine
|
||||
NetworkClient *network.Client
|
||||
}
|
||||
|
||||
// NewWorker creates a new pre-warming worker
|
||||
func NewWorker(cfg Config) *Worker {
|
||||
if cfg.Interval <= 0 {
|
||||
cfg.Interval = 1 * time.Hour
|
||||
}
|
||||
if cfg.MaxConcurrent <= 0 {
|
||||
cfg.MaxConcurrent = 5
|
||||
}
|
||||
if cfg.TopPackages <= 0 {
|
||||
cfg.TopPackages = 100
|
||||
}
|
||||
|
||||
worker := &Worker{
|
||||
cache: cfg.CacheManager,
|
||||
analytics: cfg.Analytics,
|
||||
client: cfg.NetworkClient,
|
||||
interval: cfg.Interval,
|
||||
maxConcurrent: cfg.MaxConcurrent,
|
||||
enabled: cfg.Enabled,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
if cfg.Enabled {
|
||||
log.Info().
|
||||
Dur("interval", cfg.Interval).
|
||||
Int("max_concurrent", cfg.MaxConcurrent).
|
||||
Msg("Pre-warming worker initialized")
|
||||
} else {
|
||||
log.Info().Msg("Pre-warming worker disabled")
|
||||
}
|
||||
|
||||
return worker
|
||||
}
|
||||
|
||||
// Start begins the pre-warming worker
|
||||
func (w *Worker) Start(ctx context.Context) {
|
||||
if !w.enabled {
|
||||
log.Debug().Msg("Pre-warming worker is disabled, not starting")
|
||||
return
|
||||
}
|
||||
|
||||
w.wg.Add(1)
|
||||
go w.run(ctx)
|
||||
log.Info().Msg("Pre-warming worker started")
|
||||
}
|
||||
|
||||
// run is the main worker loop
|
||||
func (w *Worker) run(ctx context.Context) {
|
||||
defer w.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(w.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Run immediately on start
|
||||
w.prewarmPopularPackages(ctx)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info().Msg("Pre-warming worker stopping due to context cancellation")
|
||||
return
|
||||
case <-w.stopChan:
|
||||
log.Info().Msg("Pre-warming worker stopped")
|
||||
return
|
||||
case <-ticker.C:
|
||||
w.prewarmPopularPackages(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// prewarmPopularPackages fetches and caches popular packages
|
||||
func (w *Worker) prewarmPopularPackages(ctx context.Context) {
|
||||
log.Info().Msg("Starting pre-warming cycle")
|
||||
|
||||
// Get popular packages from analytics
|
||||
popularPackages := w.analytics.GetTopPackages(100)
|
||||
if len(popularPackages) == 0 {
|
||||
log.Debug().Msg("No popular packages found for pre-warming")
|
||||
return
|
||||
}
|
||||
|
||||
// Get trending packages for additional candidates
|
||||
trendingPackages := w.analytics.GetTrendingPackages(50)
|
||||
|
||||
// Combine and deduplicate
|
||||
packages := w.combinePackages(popularPackages, trendingPackages)
|
||||
|
||||
log.Info().
|
||||
Int("packages", len(packages)).
|
||||
Msg("Identified packages for pre-warming")
|
||||
|
||||
// Create work queue
|
||||
workChan := make(chan PackageInfo, len(packages))
|
||||
for _, pkg := range packages {
|
||||
workChan <- PackageInfo{
|
||||
Registry: pkg.Registry,
|
||||
Name: pkg.Name,
|
||||
Version: "latest", // Pre-warm latest version
|
||||
Priority: int(pkg.Downloads),
|
||||
}
|
||||
}
|
||||
close(workChan)
|
||||
|
||||
// Start worker goroutines
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < w.maxConcurrent; i++ {
|
||||
wg.Add(1)
|
||||
go func(workerID int) {
|
||||
defer wg.Done()
|
||||
w.processPackages(ctx, workerID, workChan)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
log.Info().Msg("Pre-warming cycle completed")
|
||||
}
|
||||
|
||||
// processPackages processes packages from the work queue
|
||||
func (w *Worker) processPackages(ctx context.Context, workerID int, workChan <-chan PackageInfo) {
|
||||
for pkg := range workChan {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
w.prewarmPackage(ctx, pkg, workerID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// prewarmPackage fetches and caches a single package
|
||||
func (w *Worker) prewarmPackage(ctx context.Context, pkg PackageInfo, workerID int) {
|
||||
log.Debug().
|
||||
Int("worker", workerID).
|
||||
Str("registry", pkg.Registry).
|
||||
Str("package", pkg.Name).
|
||||
Str("version", pkg.Version).
|
||||
Msg("Pre-warming package")
|
||||
|
||||
// Build URL based on registry
|
||||
url := w.buildPackageURL(pkg)
|
||||
if url == "" {
|
||||
log.Warn().
|
||||
Str("registry", pkg.Registry).
|
||||
Str("package", pkg.Name).
|
||||
Msg("Cannot build URL for registry")
|
||||
return
|
||||
}
|
||||
|
||||
// Fetch package from upstream
|
||||
reqCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
body, statusCode, err := w.client.Get(reqCtx, url, nil)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("package", pkg.Name).
|
||||
Msg("Failed to fetch package for pre-warming")
|
||||
return
|
||||
}
|
||||
defer body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
if statusCode != 200 {
|
||||
log.Warn().
|
||||
Int("status", statusCode).
|
||||
Str("package", pkg.Name).
|
||||
Msg("Non-200 response for package")
|
||||
return
|
||||
}
|
||||
|
||||
// Cache the package
|
||||
// In a real implementation, this would read the response body and store it
|
||||
log.Info().
|
||||
Str("package", pkg.Name).
|
||||
Str("version", pkg.Version).
|
||||
Msg("Successfully pre-warmed package")
|
||||
}
|
||||
|
||||
// buildPackageURL builds the upstream URL for a package
|
||||
func (w *Worker) buildPackageURL(pkg PackageInfo) string {
|
||||
// This is simplified - in reality, each registry has different URL patterns
|
||||
switch pkg.Registry {
|
||||
case "npm":
|
||||
return "https://registry.npmjs.org/" + pkg.Name
|
||||
case "pypi":
|
||||
return "https://pypi.org/simple/" + pkg.Name + "/"
|
||||
case "go":
|
||||
// Go modules use different URL patterns
|
||||
return "https://proxy.golang.org/" + pkg.Name + "/@latest"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// combinePackages merges popular and trending packages, removing duplicates
|
||||
func (w *Worker) combinePackages(popular, trending []analytics.PopularPackage) []analytics.PopularPackage {
|
||||
seen := make(map[string]bool)
|
||||
result := make([]analytics.PopularPackage, 0, len(popular)+len(trending))
|
||||
|
||||
for _, pkg := range popular {
|
||||
key := pkg.Registry + ":" + pkg.Name
|
||||
if !seen[key] {
|
||||
result = append(result, pkg)
|
||||
seen[key] = true
|
||||
}
|
||||
}
|
||||
|
||||
for _, pkg := range trending {
|
||||
key := pkg.Registry + ":" + pkg.Name
|
||||
if !seen[key] {
|
||||
result = append(result, pkg)
|
||||
seen[key] = true
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Stop gracefully stops the pre-warming worker
|
||||
func (w *Worker) Stop() {
|
||||
if !w.enabled {
|
||||
return
|
||||
}
|
||||
|
||||
log.Info().Msg("Stopping pre-warming worker")
|
||||
close(w.stopChan)
|
||||
w.wg.Wait()
|
||||
log.Info().Msg("Pre-warming worker stopped")
|
||||
}
|
||||
|
||||
// TriggerPrewarm manually triggers a pre-warming cycle
|
||||
func (w *Worker) TriggerPrewarm(ctx context.Context) {
|
||||
if !w.enabled {
|
||||
log.Warn().Msg("Cannot trigger pre-warm: worker is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
log.Info().Msg("Manual pre-warming triggered")
|
||||
go w.prewarmPopularPackages(ctx)
|
||||
}
|
||||
|
||||
// PrewarmPackage pre-warms a specific package
|
||||
func (w *Worker) PrewarmPackage(ctx context.Context, registry, name, version string) error {
|
||||
if !w.enabled {
|
||||
log.Warn().Msg("Pre-warming worker is disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
pkg := PackageInfo{
|
||||
Registry: registry,
|
||||
Name: name,
|
||||
Version: version,
|
||||
Priority: 100,
|
||||
}
|
||||
|
||||
w.prewarmPackage(ctx, pkg, 0)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStatus returns the current status of the pre-warming worker
|
||||
func (w *Worker) GetStatus() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"enabled": w.enabled,
|
||||
"interval": w.interval.String(),
|
||||
"max_concurrent": w.maxConcurrent,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/cache"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/network"
|
||||
)
|
||||
|
||||
// BaseHandler provides common functionality for all proxy handlers
|
||||
type BaseHandler struct {
|
||||
Cache *cache.Manager
|
||||
Client *network.Client
|
||||
Upstream string
|
||||
Registry string
|
||||
}
|
||||
|
||||
// Config holds common proxy configuration
|
||||
type Config struct {
|
||||
Upstream string // Upstream registry URL (e.g., registry.npmjs.org)
|
||||
}
|
||||
|
||||
// GetRegistry returns the registry type
|
||||
func (h *BaseHandler) GetRegistry() string {
|
||||
return h.Registry
|
||||
}
|
||||
|
||||
// NewBaseHandler creates a new base handler with common fields
|
||||
func NewBaseHandler(cache *cache.Manager, client *network.Client, registry, upstream string) *BaseHandler {
|
||||
return &BaseHandler{
|
||||
Cache: cache,
|
||||
Client: client,
|
||||
Upstream: upstream,
|
||||
Registry: registry,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,385 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/cache"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestNewBaseHandler tests base handler creation
|
||||
func TestNewBaseHandler(t *testing.T) {
|
||||
// Use nil for cache and client since we're only testing structure
|
||||
handler := NewBaseHandler(nil, nil, "npm", "https://registry.npmjs.org")
|
||||
|
||||
require.NotNil(t, handler)
|
||||
assert.Equal(t, "npm", handler.Registry)
|
||||
assert.Equal(t, "https://registry.npmjs.org", handler.Upstream)
|
||||
assert.Nil(t, handler.Cache)
|
||||
assert.Nil(t, handler.Client)
|
||||
}
|
||||
|
||||
// TestGetRegistry tests registry type retrieval
|
||||
func TestGetRegistry(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
registry string
|
||||
}{
|
||||
{"npm registry", "npm"},
|
||||
{"pypi registry", "pypi"},
|
||||
{"go registry", "go"},
|
||||
{"custom registry", "custom"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
handler := &BaseHandler{Registry: tt.registry}
|
||||
assert.Equal(t, tt.registry, handler.GetRegistry())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleUpstreamError tests upstream error handling
|
||||
func TestHandleUpstreamError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
url string
|
||||
context string
|
||||
wantStatus int
|
||||
wantContain string
|
||||
}{
|
||||
// GOOD: Standard error
|
||||
{
|
||||
name: "connection error",
|
||||
err: errors.New("connection refused"),
|
||||
url: "https://registry.npmjs.org/react",
|
||||
context: "package",
|
||||
wantStatus: http.StatusBadGateway,
|
||||
wantContain: "Failed to fetch package",
|
||||
},
|
||||
// WRONG: Timeout error
|
||||
{
|
||||
name: "timeout error",
|
||||
err: context.DeadlineExceeded,
|
||||
url: "https://registry.npmjs.org/lodash",
|
||||
context: "metadata",
|
||||
wantStatus: http.StatusBadGateway,
|
||||
wantContain: "Failed to fetch metadata",
|
||||
},
|
||||
// EDGE: Empty context
|
||||
{
|
||||
name: "empty context",
|
||||
err: errors.New("error"),
|
||||
url: "https://example.com",
|
||||
context: "",
|
||||
wantStatus: http.StatusBadGateway,
|
||||
wantContain: "Failed to fetch",
|
||||
},
|
||||
// EDGE: Long URL
|
||||
{
|
||||
name: "long URL",
|
||||
err: errors.New("error"),
|
||||
url: "https://registry.npmjs.org/@scope/very-long-package-name/versions/1.2.3",
|
||||
context: "package",
|
||||
wantStatus: http.StatusBadGateway,
|
||||
wantContain: "Failed to fetch package",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
HandleUpstreamError(w, tt.err, tt.url, tt.context)
|
||||
|
||||
assert.Equal(t, tt.wantStatus, w.Code)
|
||||
assert.Contains(t, w.Body.String(), tt.wantContain)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckUpstreamStatus tests upstream status validation
|
||||
func TestCheckUpstreamStatus(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body io.ReadCloser
|
||||
wantErr bool
|
||||
errContains string
|
||||
bodyClosed bool
|
||||
}{
|
||||
// GOOD: OK status
|
||||
{
|
||||
name: "200 OK",
|
||||
statusCode: http.StatusOK,
|
||||
body: io.NopCloser(strings.NewReader("success")),
|
||||
wantErr: false,
|
||||
},
|
||||
// WRONG: Not found
|
||||
{
|
||||
name: "404 Not Found",
|
||||
statusCode: http.StatusNotFound,
|
||||
body: io.NopCloser(strings.NewReader("not found")),
|
||||
wantErr: true,
|
||||
errContains: "upstream returned status 404",
|
||||
},
|
||||
// WRONG: Server error
|
||||
{
|
||||
name: "500 Internal Server Error",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
body: io.NopCloser(strings.NewReader("error")),
|
||||
wantErr: true,
|
||||
errContains: "upstream returned status 500",
|
||||
},
|
||||
// BAD: Unauthorized
|
||||
{
|
||||
name: "401 Unauthorized",
|
||||
statusCode: http.StatusUnauthorized,
|
||||
body: io.NopCloser(strings.NewReader("unauthorized")),
|
||||
wantErr: true,
|
||||
errContains: "upstream returned status 401",
|
||||
},
|
||||
// EDGE: Nil body
|
||||
{
|
||||
name: "nil body with error",
|
||||
statusCode: http.StatusNotFound,
|
||||
body: nil,
|
||||
wantErr: true,
|
||||
errContains: "upstream returned status 404",
|
||||
},
|
||||
// EDGE: Redirect status
|
||||
{
|
||||
name: "302 Found",
|
||||
statusCode: http.StatusFound,
|
||||
body: io.NopCloser(strings.NewReader("redirect")),
|
||||
wantErr: true,
|
||||
errContains: "upstream returned status 302",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := CheckUpstreamStatus(tt.statusCode, tt.body)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleInvalidRequest tests invalid request handling
|
||||
func TestHandleInvalidRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
registry string
|
||||
wantStatus int
|
||||
wantContain string
|
||||
}{
|
||||
{
|
||||
name: "npm invalid request",
|
||||
registry: "npm",
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantContain: "Invalid npm request",
|
||||
},
|
||||
{
|
||||
name: "pypi invalid request",
|
||||
registry: "pypi",
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantContain: "Invalid pypi request",
|
||||
},
|
||||
{
|
||||
name: "go invalid request",
|
||||
registry: "go",
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantContain: "Invalid go request",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
HandleInvalidRequest(w, tt.registry)
|
||||
|
||||
assert.Equal(t, tt.wantStatus, w.Code)
|
||||
assert.Contains(t, w.Body.String(), tt.wantContain)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleInternalError tests internal error handling
|
||||
func TestHandleInternalError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
context string
|
||||
wantStatus int
|
||||
wantContain string
|
||||
}{
|
||||
{
|
||||
name: "database error",
|
||||
err: errors.New("database connection failed"),
|
||||
context: "database",
|
||||
wantStatus: http.StatusInternalServerError,
|
||||
wantContain: "Internal error: database",
|
||||
},
|
||||
{
|
||||
name: "cache error",
|
||||
err: errors.New("cache write failed"),
|
||||
context: "cache",
|
||||
wantStatus: http.StatusInternalServerError,
|
||||
wantContain: "Internal error: cache",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
HandleInternalError(w, tt.err, tt.context)
|
||||
|
||||
assert.Equal(t, tt.wantStatus, w.Code)
|
||||
assert.Contains(t, w.Body.String(), tt.wantContain)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Note: FetchFromUpstream tests would require mocking cache.Manager and network.Client
|
||||
// which requires concrete implementations. Integration tests cover this functionality.
|
||||
|
||||
// TestWriteResponse tests HTTP response writing
|
||||
func TestWriteResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data string
|
||||
contentType string
|
||||
wantStatus int
|
||||
wantBody string
|
||||
wantErr bool
|
||||
}{
|
||||
// GOOD: Write tarball
|
||||
{
|
||||
name: "write tarball",
|
||||
data: "package data here",
|
||||
contentType: "application/octet-stream",
|
||||
wantStatus: http.StatusOK,
|
||||
wantBody: "package data here",
|
||||
wantErr: false,
|
||||
},
|
||||
// GOOD: Write JSON
|
||||
{
|
||||
name: "write JSON metadata",
|
||||
data: `{"name":"react","version":"18.2.0"}`,
|
||||
contentType: "application/json",
|
||||
wantStatus: http.StatusOK,
|
||||
wantBody: `{"name":"react","version":"18.2.0"}`,
|
||||
wantErr: false,
|
||||
},
|
||||
// EDGE: Empty data
|
||||
{
|
||||
name: "empty data",
|
||||
data: "",
|
||||
contentType: "text/plain",
|
||||
wantStatus: http.StatusOK,
|
||||
wantBody: "",
|
||||
wantErr: false,
|
||||
},
|
||||
// EDGE: Large data
|
||||
{
|
||||
name: "large data",
|
||||
data: strings.Repeat("x", 100000),
|
||||
contentType: "application/octet-stream",
|
||||
wantStatus: http.StatusOK,
|
||||
wantBody: strings.Repeat("x", 100000),
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
entry := &cache.CacheEntry{
|
||||
Data: io.NopCloser(bytes.NewReader([]byte(tt.data))),
|
||||
}
|
||||
|
||||
err := WriteResponse(w, entry, tt.contentType)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.contentType, w.Header().Get("Content-Type"))
|
||||
assert.Equal(t, tt.wantBody, w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseHandlerFields tests that BaseHandler fields are properly set
|
||||
func TestBaseHandlerFields(t *testing.T) {
|
||||
handler := NewBaseHandler(nil, nil, "npm", "https://registry.npmjs.org")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
field string
|
||||
expected interface{}
|
||||
}{
|
||||
{"registry field", "registry", "npm"},
|
||||
{"upstream field", "upstream", "https://registry.npmjs.org"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
switch tt.field {
|
||||
case "registry":
|
||||
assert.Equal(t, tt.expected, handler.Registry)
|
||||
case "upstream":
|
||||
assert.Equal(t, tt.expected, handler.Upstream)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProxyHandlerInterface tests that BaseHandler can be used as ProxyHandler
|
||||
func TestProxyHandlerInterface(t *testing.T) {
|
||||
handler := NewBaseHandler(nil, nil, "npm", "https://registry.npmjs.org")
|
||||
|
||||
// Verify GetRegistry works
|
||||
registry := handler.GetRegistry()
|
||||
assert.Equal(t, "npm", registry)
|
||||
}
|
||||
|
||||
// TestConcurrentWriteResponse tests that WriteResponse is safe for concurrent use
|
||||
func TestConcurrentWriteResponse(t *testing.T) {
|
||||
const numGoroutines = 10
|
||||
|
||||
errs := make(chan error, numGoroutines)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(n int) {
|
||||
w := httptest.NewRecorder()
|
||||
data := strings.Repeat("x", 1000)
|
||||
entry := &cache.CacheEntry{
|
||||
Data: io.NopCloser(bytes.NewReader([]byte(data))),
|
||||
}
|
||||
|
||||
err := WriteResponse(w, entry, "text/plain")
|
||||
errs <- err
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Collect results
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
err := <-errs
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// HandleUpstreamError logs an error and sends an HTTP 502 Bad Gateway response
|
||||
// This is the common pattern used across all proxy handlers when upstream fetch fails
|
||||
func HandleUpstreamError(w http.ResponseWriter, err error, url, context string) {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("url", url).
|
||||
Str("context", context).
|
||||
Msg("Failed to fetch from upstream")
|
||||
|
||||
http.Error(w, fmt.Sprintf("Failed to fetch %s", context), http.StatusBadGateway)
|
||||
}
|
||||
|
||||
// CheckUpstreamStatus validates HTTP status code from upstream
|
||||
// Returns error if status is not OK, closing body if needed
|
||||
func CheckUpstreamStatus(statusCode int, body io.ReadCloser) error {
|
||||
if statusCode != http.StatusOK {
|
||||
if body != nil {
|
||||
body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
}
|
||||
return fmt.Errorf("upstream returned status %d", statusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleInvalidRequest sends a 400 Bad Request response for invalid proxy requests
|
||||
func HandleInvalidRequest(w http.ResponseWriter, registry string) {
|
||||
http.Error(w, fmt.Sprintf("Invalid %s request", registry), http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// HandleInternalError logs an internal error and sends 500 response
|
||||
func HandleInternalError(w http.ResponseWriter, err error, context string) {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("context", context).
|
||||
Msg("Internal error processing request")
|
||||
|
||||
http.Error(w, fmt.Sprintf("Internal error: %s", context), http.StatusInternalServerError)
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/cache"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/network"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// FetchFromUpstream is a common helper to fetch content from upstream with caching
|
||||
// This encapsulates the common pattern of: cache.Get -> network.Get -> error handling
|
||||
func FetchFromUpstream(
|
||||
ctx context.Context,
|
||||
cacheManager *cache.Manager,
|
||||
client *network.Client,
|
||||
registry, name, version, upstreamURL string,
|
||||
) (*cache.CacheEntry, error) {
|
||||
entry, err := cacheManager.Get(ctx, registry, name, version, func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
body, statusCode, err := client.Get(ctx, upstreamURL, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if err := CheckUpstreamStatus(statusCode, body); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return body, upstreamURL, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("url", upstreamURL).
|
||||
Str("registry", registry).
|
||||
Str("name", name).
|
||||
Str("version", version).
|
||||
Msg("Failed to fetch package from upstream")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
// WriteResponse writes the cache entry data to the HTTP response writer
|
||||
// Sets appropriate content type and handles errors
|
||||
func WriteResponse(w http.ResponseWriter, entry *cache.CacheEntry, contentType string) error {
|
||||
defer entry.Data.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
if _, err := io.Copy(w, entry.Data); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to write response")
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ProxyHandler defines the common interface for all registry proxies
|
||||
type ProxyHandler interface {
|
||||
http.Handler // ServeHTTP(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
// GetRegistry returns the registry type (npm, pypi, go)
|
||||
GetRegistry() string
|
||||
|
||||
// Health checks if the proxy can reach its upstream
|
||||
Health(ctx context.Context) error
|
||||
}
|
||||
|
||||
// Stats represents proxy statistics
|
||||
type Stats struct {
|
||||
Registry string
|
||||
TotalRequests int64
|
||||
CacheHits int64
|
||||
CacheMisses int64
|
||||
UpstreamErrors int64
|
||||
AvgResponseTime time.Duration
|
||||
LastUpdated time.Time
|
||||
}
|
||||
@@ -0,0 +1,485 @@
|
||||
package goproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/auth"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/cache"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/errors"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/network"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/vcs"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Handler implements the GOPROXY protocol
|
||||
type Handler struct {
|
||||
cache *cache.Manager
|
||||
client *network.Client
|
||||
upstream string
|
||||
sumDBURL string
|
||||
credExtractor *auth.CredentialExtractor
|
||||
credHasher *auth.CredentialHasher
|
||||
credValidator *auth.GoValidator
|
||||
validationCache *auth.ValidationCache
|
||||
gitFetcher *vcs.GitFetcher
|
||||
moduleBuilder *vcs.ModuleBuilder
|
||||
}
|
||||
|
||||
// Config holds Go proxy configuration
|
||||
type Config struct {
|
||||
Upstream string // Upstream Go proxy (e.g., proxy.golang.org)
|
||||
SumDBURL string // Checksum database URL
|
||||
CredStore *vcs.CredentialStore // Optional credential store for git access
|
||||
}
|
||||
|
||||
// New creates a new Go proxy handler
|
||||
func New(cacheManager *cache.Manager, client *network.Client, config Config) *Handler {
|
||||
if config.Upstream == "" {
|
||||
config.Upstream = "https://proxy.golang.org"
|
||||
}
|
||||
|
||||
if config.SumDBURL == "" {
|
||||
config.SumDBURL = "https://sum.golang.org"
|
||||
}
|
||||
|
||||
// Use provided credential store or create empty one
|
||||
credStore := config.CredStore
|
||||
if credStore == nil {
|
||||
credStore = vcs.NewCredentialStore()
|
||||
}
|
||||
|
||||
return &Handler{
|
||||
cache: cacheManager,
|
||||
client: client,
|
||||
upstream: config.Upstream,
|
||||
sumDBURL: config.SumDBURL,
|
||||
credExtractor: auth.NewCredentialExtractor(),
|
||||
credHasher: auth.NewCredentialHasher(),
|
||||
credValidator: auth.NewGoValidator(),
|
||||
validationCache: auth.NewValidationCache(5 * time.Minute),
|
||||
gitFetcher: vcs.NewGitFetcher("", credStore),
|
||||
moduleBuilder: vcs.NewModuleBuilder(),
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP handles GOPROXY protocol requests
|
||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
// Path is already stripped by http.StripPrefix in app.go
|
||||
path := r.URL.Path
|
||||
|
||||
log.Debug().
|
||||
Str("path", path).
|
||||
Msg("Processing Go proxy request")
|
||||
|
||||
// Parse GOPROXY request
|
||||
// Formats:
|
||||
// /@v/list - list versions
|
||||
// /@v/$version.info - version info
|
||||
// /@v/$version.mod - go.mod file
|
||||
// /@v/$version.zip - module zip
|
||||
// /@latest - latest version
|
||||
|
||||
log.Debug().Str("path", path).Msg("Go proxy request")
|
||||
|
||||
// Route request based on path
|
||||
if strings.HasPrefix(path, "/sumdb/") {
|
||||
h.handleSumDB(ctx, w, r, path)
|
||||
} else if strings.HasSuffix(path, "/@v/list") {
|
||||
h.handleList(ctx, w, r, path)
|
||||
} else if strings.Contains(path, "/@v/") && strings.HasSuffix(path, ".info") {
|
||||
h.handleInfo(ctx, w, r, path)
|
||||
} else if strings.Contains(path, "/@v/") && strings.HasSuffix(path, ".mod") {
|
||||
h.handleMod(ctx, w, r, path)
|
||||
} else if strings.Contains(path, "/@v/") && strings.HasSuffix(path, ".zip") {
|
||||
h.handleZip(ctx, w, r, path)
|
||||
} else if strings.HasSuffix(path, "/@latest") {
|
||||
h.handleLatest(ctx, w, r, path)
|
||||
} else {
|
||||
http.Error(w, "Invalid Go proxy request", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
// handleList handles /@v/list requests
|
||||
func (h *Handler) handleList(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
|
||||
url := h.upstream + path
|
||||
modulePath := h.extractModulePath(path)
|
||||
|
||||
// Extract credentials from request
|
||||
credentials := h.credExtractor.Extract(r)
|
||||
|
||||
entry, err := h.cache.Get(ctx, "go", modulePath, "list", func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
// Prepare headers for upstream request
|
||||
headers := make(map[string]string)
|
||||
if credentials != "" {
|
||||
headers["Authorization"] = credentials
|
||||
}
|
||||
|
||||
body, statusCode, err := h.client.Get(ctx, url, headers)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
|
||||
}
|
||||
return body, url, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", url).Msg("Failed to fetch version list")
|
||||
http.Error(w, "Failed to fetch version list", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer entry.Data.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain; charset=UTF-8")
|
||||
_, _ = io.Copy(w, entry.Data) // #nosec G104 -- HTTP response write
|
||||
}
|
||||
|
||||
// handleInfo handles /@v/$version.info requests
|
||||
func (h *Handler) handleInfo(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
|
||||
url := h.upstream + path
|
||||
modulePath := h.extractModulePath(path)
|
||||
version := h.extractVersion(path, ".info")
|
||||
// Use .info suffix to distinguish from .mod and .zip in cache
|
||||
cacheKey := modulePath + "/@v/" + version + ".info"
|
||||
|
||||
// Extract credentials from request
|
||||
credentials := h.credExtractor.Extract(r)
|
||||
|
||||
entry, err := h.cache.Get(ctx, "go", cacheKey, version, func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
// Prepare headers for upstream request
|
||||
headers := make(map[string]string)
|
||||
if credentials != "" {
|
||||
headers["Authorization"] = credentials
|
||||
}
|
||||
|
||||
body, statusCode, err := h.client.Get(ctx, url, headers)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
|
||||
}
|
||||
return body, url, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", url).Msg("Failed to fetch version info")
|
||||
http.Error(w, "Failed to fetch version info", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer entry.Data.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
||||
_, _ = io.Copy(w, entry.Data) // #nosec G104 -- HTTP response write
|
||||
}
|
||||
|
||||
// handleMod handles /@v/$version.mod requests
|
||||
func (h *Handler) handleMod(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
|
||||
url := h.upstream + path
|
||||
modulePath := h.extractModulePath(path)
|
||||
version := h.extractVersion(path, ".mod")
|
||||
// Use .mod suffix to distinguish from .info and .zip in cache
|
||||
cacheKey := modulePath + "/@v/" + version + ".mod"
|
||||
|
||||
// Extract credentials from request
|
||||
credentials := h.credExtractor.Extract(r)
|
||||
|
||||
entry, err := h.cache.Get(ctx, "go", cacheKey, version, func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
// Prepare headers for upstream request
|
||||
headers := make(map[string]string)
|
||||
if credentials != "" {
|
||||
headers["Authorization"] = credentials
|
||||
}
|
||||
|
||||
body, statusCode, err := h.client.Get(ctx, url, headers)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
|
||||
}
|
||||
return body, url, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", url).Msg("Failed to fetch go.mod")
|
||||
http.Error(w, "Failed to fetch go.mod", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer entry.Data.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain; charset=UTF-8")
|
||||
_, _ = io.Copy(w, entry.Data) // #nosec G104 -- HTTP response write
|
||||
}
|
||||
|
||||
// handleZip handles /@v/$version.zip requests
|
||||
func (h *Handler) handleZip(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
|
||||
url := h.upstream + path
|
||||
modulePath := h.extractModulePath(path)
|
||||
version := h.extractVersion(path, ".zip")
|
||||
// Use .zip suffix to distinguish from .info and .mod in cache
|
||||
cacheKey := modulePath + "/@v/" + version + ".zip"
|
||||
|
||||
// Extract credentials from request
|
||||
credentials := h.credExtractor.Extract(r)
|
||||
credHash := h.credHasher.Hash(credentials)
|
||||
|
||||
log.Debug().
|
||||
Str("path", path).
|
||||
Str("module", modulePath).
|
||||
Str("version", version).
|
||||
Str("url", url).
|
||||
Str("cred_hash", credHash).
|
||||
Bool("has_credentials", credentials != "").
|
||||
Msg("Handling Go module zip request")
|
||||
|
||||
entry, err := h.cache.Get(ctx, "go", cacheKey, version, func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
// Prepare headers for upstream request
|
||||
headers := make(map[string]string)
|
||||
if credentials != "" {
|
||||
headers["Authorization"] = credentials
|
||||
}
|
||||
|
||||
// Try upstream proxy first (fast path for public modules)
|
||||
body, statusCode, err := h.client.Get(ctx, url, headers)
|
||||
if err == nil && statusCode == http.StatusOK {
|
||||
return body, url, nil
|
||||
}
|
||||
|
||||
// If upstream failed with 404 or 403, try git fallback (private modules)
|
||||
if statusCode == http.StatusNotFound || statusCode == http.StatusForbidden {
|
||||
if body != nil {
|
||||
body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("module", modulePath).
|
||||
Str("version", version).
|
||||
Int("upstream_status", statusCode).
|
||||
Msg("Upstream proxy returned not found, trying git fallback")
|
||||
|
||||
return h.fetchModuleFromGit(ctx, modulePath, version, credentials)
|
||||
}
|
||||
|
||||
// Other errors
|
||||
if body != nil {
|
||||
body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", url).Msg("Failed to fetch module zip")
|
||||
|
||||
// Check if error is a security violation - return 403 Forbidden
|
||||
if ghErr, ok := err.(*errors.Error); ok && ghErr.Code == errors.ErrCodeSecurityViolation {
|
||||
http.Error(w, fmt.Sprintf("Package blocked: %s", ghErr.Message), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// All other errors return 502 Bad Gateway (upstream issues)
|
||||
http.Error(w, "Failed to fetch module zip", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer entry.Data.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
// CRITICAL SECURITY CHECK: If module requires auth, validate credentials
|
||||
if entry.Package != nil && entry.Package.RequiresAuth {
|
||||
// Check validation cache first
|
||||
allowed, cached, reason := h.validationCache.Get(credHash, modulePath)
|
||||
if cached {
|
||||
if !allowed {
|
||||
log.Warn().
|
||||
Str("module", modulePath).
|
||||
Str("version", version).
|
||||
Str("reason", reason).
|
||||
Msg("Access denied (cached validation)")
|
||||
http.Error(w, "Module not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
log.Debug().
|
||||
Str("module", modulePath).
|
||||
Str("version", version).
|
||||
Msg("Access granted (cached validation)")
|
||||
} else {
|
||||
// Validate with upstream using git ls-remote
|
||||
log.Debug().
|
||||
Str("module", modulePath).
|
||||
Str("version", version).
|
||||
Str("provider", entry.Package.AuthProvider).
|
||||
Msg("Validating credentials with upstream")
|
||||
|
||||
allowed, err := h.credValidator.ValidateAccess(ctx, modulePath, credentials)
|
||||
if err != nil {
|
||||
reason = err.Error()
|
||||
}
|
||||
|
||||
// Cache validation result
|
||||
h.validationCache.Set(credHash, modulePath, allowed, reason)
|
||||
|
||||
if !allowed {
|
||||
log.Warn().
|
||||
Str("module", modulePath).
|
||||
Str("version", version).
|
||||
Err(err).
|
||||
Msg("Access denied by upstream")
|
||||
// Return 404 (same as GitHub does for private repos)
|
||||
http.Error(w, "Module not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("module", modulePath).
|
||||
Str("version", version).
|
||||
Msg("Access granted by upstream")
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/zip")
|
||||
_, _ = io.Copy(w, entry.Data) // #nosec G104 -- HTTP response write
|
||||
}
|
||||
|
||||
// handleLatest handles /@latest requests
|
||||
func (h *Handler) handleLatest(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
|
||||
url := h.upstream + path
|
||||
modulePath := h.extractModulePath(path)
|
||||
|
||||
// Extract credentials from request
|
||||
credentials := h.credExtractor.Extract(r)
|
||||
|
||||
entry, err := h.cache.Get(ctx, "go", modulePath, "latest", func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
// Prepare headers for upstream request
|
||||
headers := make(map[string]string)
|
||||
if credentials != "" {
|
||||
headers["Authorization"] = credentials
|
||||
}
|
||||
|
||||
body, statusCode, err := h.client.Get(ctx, url, headers)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
|
||||
}
|
||||
return body, url, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", url).Msg("Failed to fetch latest version")
|
||||
http.Error(w, "Failed to fetch latest version", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer entry.Data.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
||||
_, _ = io.Copy(w, entry.Data) // #nosec G104 -- HTTP response write
|
||||
}
|
||||
|
||||
// handleSumDB handles sumdb requests (checksum database)
|
||||
func (h *Handler) handleSumDB(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
|
||||
// path format: /sumdb/sum.golang.org/...
|
||||
// Remove /sumdb/ prefix and proxy to sumdb URL
|
||||
sumdbPath := strings.TrimPrefix(path, "/sumdb/sum.golang.org")
|
||||
url := h.sumDBURL + sumdbPath
|
||||
|
||||
log.Debug().Str("url", url).Msg("Proxying sumdb request")
|
||||
|
||||
// Sumdb requests should not be cached, proxy directly
|
||||
body, statusCode, err := h.client.Get(ctx, url, nil)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", url).Msg("Failed to fetch from sumdb")
|
||||
http.Error(w, "Failed to fetch from sumdb", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
if statusCode != http.StatusOK {
|
||||
log.Error().Int("status", statusCode).Str("url", url).Msg("Sumdb returned non-OK status")
|
||||
http.Error(w, "Sumdb error", statusCode)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain; charset=UTF-8")
|
||||
_, _ = io.Copy(w, body) // #nosec G104 -- HTTP response write
|
||||
}
|
||||
|
||||
// extractVersion extracts version from path
|
||||
func (h *Handler) extractVersion(path, suffix string) string {
|
||||
// path format: /module/path/@v/v1.2.3.suffix
|
||||
parts := strings.Split(path, "/@v/")
|
||||
if len(parts) != 2 {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSuffix(parts[1], suffix)
|
||||
}
|
||||
|
||||
// extractModulePath extracts the clean module path from a GOPROXY path
|
||||
// Examples:
|
||||
//
|
||||
// /github.com/avast/retry-go/v4/@v/v4.6.1.zip -> github.com/avast/retry-go/v4
|
||||
// /golang.org/x/net/@v/v0.40.0.mod -> golang.org/x/net
|
||||
// /github.com/user/repo/@v/list -> github.com/user/repo
|
||||
func (h *Handler) extractModulePath(path string) string {
|
||||
// Remove leading slash
|
||||
path = strings.TrimPrefix(path, "/")
|
||||
|
||||
// Split on /@v/ to get the module path
|
||||
parts := strings.Split(path, "/@v/")
|
||||
if len(parts) > 0 {
|
||||
return parts[0]
|
||||
}
|
||||
|
||||
// Fallback: remove /@latest suffix if present
|
||||
return strings.TrimSuffix(path, "/@latest")
|
||||
}
|
||||
|
||||
// fetchModuleFromGit fetches a Go module directly from git repository
|
||||
func (h *Handler) fetchModuleFromGit(ctx context.Context, modulePath, version, credentials string) (io.ReadCloser, string, error) {
|
||||
log.Info().
|
||||
Str("module", modulePath).
|
||||
Str("version", version).
|
||||
Msg("Fetching module from git repository")
|
||||
|
||||
// 1. Fetch module source from git
|
||||
srcPath, err := h.gitFetcher.FetchModule(ctx, modulePath, version, credentials)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("git fetch failed: %w", err)
|
||||
}
|
||||
defer h.gitFetcher.Cleanup(srcPath)
|
||||
|
||||
// 2. Validate module
|
||||
if err := h.moduleBuilder.ValidateModule(ctx, srcPath, modulePath); err != nil {
|
||||
return nil, "", fmt.Errorf("module validation failed: %w", err)
|
||||
}
|
||||
|
||||
// 3. Build module zip
|
||||
zipReader, err := h.moduleBuilder.BuildModuleZip(ctx, srcPath, modulePath, version)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("module zip build failed: %w", err)
|
||||
}
|
||||
|
||||
// Create source URL for logging
|
||||
sourceURL := fmt.Sprintf("git+https://%s@%s", modulePath, version)
|
||||
|
||||
log.Info().
|
||||
Str("module", modulePath).
|
||||
Str("version", version).
|
||||
Str("source", sourceURL).
|
||||
Msg("Successfully built module from git")
|
||||
|
||||
return zipReader, sourceURL, nil
|
||||
}
|
||||
@@ -0,0 +1,377 @@
|
||||
package npm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/auth"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/cache"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/errors"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/network"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Handler implements the NPM registry protocol
|
||||
type Handler struct {
|
||||
cache *cache.Manager
|
||||
client *network.Client
|
||||
upstream string
|
||||
credExtractor *auth.CredentialExtractor
|
||||
credHasher *auth.CredentialHasher
|
||||
credValidator *auth.NPMValidator
|
||||
validationCache *auth.ValidationCache
|
||||
}
|
||||
|
||||
// Config holds NPM proxy configuration
|
||||
type Config struct {
|
||||
Upstream string // Upstream NPM registry (e.g., registry.npmjs.org)
|
||||
}
|
||||
|
||||
// New creates a new NPM proxy handler
|
||||
func New(cacheManager *cache.Manager, client *network.Client, config Config) *Handler {
|
||||
if config.Upstream == "" {
|
||||
config.Upstream = "https://registry.npmjs.org"
|
||||
}
|
||||
|
||||
return &Handler{
|
||||
cache: cacheManager,
|
||||
client: client,
|
||||
upstream: config.Upstream,
|
||||
credExtractor: auth.NewCredentialExtractor(),
|
||||
credHasher: auth.NewCredentialHasher(),
|
||||
credValidator: auth.NewNPMValidator(),
|
||||
validationCache: auth.NewValidationCache(5 * time.Minute),
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP handles NPM registry requests
|
||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
path := strings.TrimPrefix(r.URL.Path, "/npm")
|
||||
|
||||
log.Debug().Str("path", path).Str("method", r.Method).Msg("NPM proxy request")
|
||||
|
||||
// Handle different NPM request types
|
||||
// Check for tarballs FIRST before special endpoints (tarballs also contain "/-/")
|
||||
if isTarballRequest(path) {
|
||||
// Package tarball: /@scope/package/-/package-version.tgz
|
||||
h.handleTarball(ctx, w, r, path)
|
||||
} else if strings.Contains(path, "/-/") {
|
||||
// Special NPM endpoints (e.g., /-/ping, /-/user/token)
|
||||
h.handleSpecial(ctx, w, r, path)
|
||||
} else if isPackageMetadata(path) {
|
||||
// Package metadata: /@scope/package or /package
|
||||
h.handleMetadata(ctx, w, r, path)
|
||||
} else {
|
||||
http.Error(w, "Invalid NPM request", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
// handleMetadata handles package metadata requests
|
||||
func (h *Handler) handleMetadata(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
|
||||
url := h.upstream + path
|
||||
packageName := extractPackageName(path)
|
||||
|
||||
entry, err := h.cache.Get(ctx, "npm", packageName, "metadata", func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
body, statusCode, err := h.client.Get(ctx, url, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
|
||||
}
|
||||
return body, url, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", url).Msg("Failed to fetch package metadata")
|
||||
http.Error(w, "Failed to fetch package metadata", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer entry.Data.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
// Read metadata into memory for URL rewriting
|
||||
var buf bytes.Buffer
|
||||
if _, err := io.Copy(&buf, entry.Data); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to read metadata")
|
||||
http.Error(w, "Failed to read metadata", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse JSON metadata
|
||||
var metadata map[string]interface{}
|
||||
if err := json.Unmarshal(buf.Bytes(), &metadata); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to parse metadata JSON")
|
||||
http.Error(w, "Failed to parse metadata", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Rewrite tarball URLs to point to our proxy
|
||||
proxyBaseURL := getProxyBaseURL(r)
|
||||
rewriteMetadataURLs(metadata, h.upstream, proxyBaseURL)
|
||||
|
||||
// Serialize modified metadata
|
||||
modifiedJSON, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to serialize modified metadata")
|
||||
http.Error(w, "Failed to serialize metadata", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
||||
_, _ = w.Write(modifiedJSON) // #nosec G104 -- Websocket buffer write
|
||||
}
|
||||
|
||||
// handleTarball handles package tarball requests
|
||||
func (h *Handler) handleTarball(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
|
||||
packageName, version := extractTarballInfo(path)
|
||||
|
||||
// Extract credentials from request
|
||||
credentials := h.credExtractor.Extract(r)
|
||||
credHash := h.credHasher.Hash(credentials)
|
||||
|
||||
// Construct proper upstream URL with /-/ format
|
||||
// Format: https://registry.npmjs.org/package/-/package-version.tgz
|
||||
tarballFilename := strings.ReplaceAll(packageName, "/", "-") + "-" + version + ".tgz"
|
||||
url := fmt.Sprintf("%s/%s/-/%s", h.upstream, packageName, tarballFilename)
|
||||
|
||||
log.Debug().
|
||||
Str("path", path).
|
||||
Str("package", packageName).
|
||||
Str("version", version).
|
||||
Str("upstream_url", url).
|
||||
Str("cred_hash", credHash).
|
||||
Bool("has_credentials", credentials != "").
|
||||
Msg("Handling tarball request")
|
||||
|
||||
// Try to get from cache first (with credential-aware key)
|
||||
entry, err := h.cache.Get(ctx, "npm", packageName, version, func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
// Prepare headers for upstream request
|
||||
headers := make(map[string]string)
|
||||
if credentials != "" {
|
||||
headers["Authorization"] = credentials
|
||||
}
|
||||
|
||||
body, statusCode, err := h.client.Get(ctx, url, headers)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
|
||||
}
|
||||
return body, url, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", url).Msg("Failed to fetch package tarball")
|
||||
|
||||
// Check if error is a security violation - return 403 Forbidden
|
||||
if ghErr, ok := err.(*errors.Error); ok && ghErr.Code == errors.ErrCodeSecurityViolation {
|
||||
http.Error(w, fmt.Sprintf("Package blocked: %s", ghErr.Message), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// All other errors return 502 Bad Gateway (upstream issues)
|
||||
http.Error(w, "Failed to fetch package tarball", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer entry.Data.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
// CRITICAL SECURITY CHECK: If package requires auth, validate credentials
|
||||
if entry.Package != nil && entry.Package.RequiresAuth {
|
||||
// Check validation cache first
|
||||
allowed, cached, reason := h.validationCache.Get(credHash, url)
|
||||
if cached {
|
||||
if !allowed {
|
||||
log.Warn().
|
||||
Str("package", packageName).
|
||||
Str("version", version).
|
||||
Str("reason", reason).
|
||||
Msg("Access denied (cached validation)")
|
||||
http.Error(w, "Access denied", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
log.Debug().
|
||||
Str("package", packageName).
|
||||
Str("version", version).
|
||||
Msg("Access granted (cached validation)")
|
||||
} else {
|
||||
// Validate with upstream
|
||||
log.Debug().
|
||||
Str("package", packageName).
|
||||
Str("version", version).
|
||||
Str("provider", entry.Package.AuthProvider).
|
||||
Msg("Validating credentials with upstream")
|
||||
|
||||
allowed, err := h.credValidator.ValidateAccess(ctx, url, credentials)
|
||||
if err != nil {
|
||||
reason = err.Error()
|
||||
}
|
||||
|
||||
// Cache validation result
|
||||
h.validationCache.Set(credHash, url, allowed, reason)
|
||||
|
||||
if !allowed {
|
||||
log.Warn().
|
||||
Str("package", packageName).
|
||||
Str("version", version).
|
||||
Err(err).
|
||||
Msg("Access denied by upstream")
|
||||
http.Error(w, "Access denied", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("package", packageName).
|
||||
Str("version", version).
|
||||
Msg("Access granted by upstream")
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = io.Copy(w, entry.Data) // #nosec G104 -- HTTP response write
|
||||
}
|
||||
|
||||
// handleSpecial handles special NPM endpoints
|
||||
func (h *Handler) handleSpecial(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
|
||||
url := h.upstream + path
|
||||
|
||||
// Don't cache special endpoints, proxy directly
|
||||
body, statusCode, err := h.client.Get(ctx, url, nil)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", url).Msg("Failed to fetch special endpoint")
|
||||
http.Error(w, "Failed to fetch from upstream", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
w.WriteHeader(statusCode)
|
||||
_, _ = io.Copy(w, body) // #nosec G104 -- HTTP response write
|
||||
}
|
||||
|
||||
// isTarballRequest checks if the request is for a tarball
|
||||
func isTarballRequest(path string) bool {
|
||||
return strings.HasSuffix(path, ".tgz") || strings.HasSuffix(path, ".tar.gz")
|
||||
}
|
||||
|
||||
// isPackageMetadata checks if the request is for package metadata
|
||||
func isPackageMetadata(path string) bool {
|
||||
// Package metadata doesn't have file extensions
|
||||
return !isTarballRequest(path) && !strings.Contains(path, "/-/")
|
||||
}
|
||||
|
||||
// extractPackageName extracts package name from path
|
||||
func extractPackageName(path string) string {
|
||||
// Remove leading slash
|
||||
path = strings.TrimPrefix(path, "/")
|
||||
|
||||
// Handle scoped packages (@scope/package)
|
||||
if strings.HasPrefix(path, "@") {
|
||||
parts := strings.Split(path, "/")
|
||||
if len(parts) >= 2 {
|
||||
return parts[0] + "/" + parts[1]
|
||||
}
|
||||
}
|
||||
|
||||
// Regular package
|
||||
parts := strings.Split(path, "/")
|
||||
if len(parts) > 0 {
|
||||
return parts[0]
|
||||
}
|
||||
|
||||
return path
|
||||
}
|
||||
|
||||
// extractTarballInfo extracts package name and version from tarball path
|
||||
func extractTarballInfo(path string) (string, string) {
|
||||
// Format: /@scope/package/-/package-version.tgz
|
||||
// or: /package/-/package-version.tgz
|
||||
// Also handle: /package/package-version.tgz (fallback)
|
||||
|
||||
// Try standard format with /-/
|
||||
parts := strings.Split(path, "/-/")
|
||||
if len(parts) == 2 {
|
||||
packageName := extractPackageName(parts[0])
|
||||
tarballName := parts[1]
|
||||
tarballName = strings.TrimSuffix(tarballName, ".tgz")
|
||||
tarballName = strings.TrimSuffix(tarballName, ".tar.gz")
|
||||
|
||||
// Remove package name prefix to get version
|
||||
prefix := strings.ReplaceAll(packageName, "/", "-") + "-"
|
||||
version := strings.TrimPrefix(tarballName, prefix)
|
||||
|
||||
return packageName, version
|
||||
}
|
||||
|
||||
// Fallback: parse path without /-/
|
||||
// Format: /package/package-version.tgz or /@scope/package/package-version.tgz
|
||||
path = strings.TrimPrefix(path, "/")
|
||||
pathParts := strings.Split(path, "/")
|
||||
|
||||
if len(pathParts) < 2 {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
var packageName, tarballName string
|
||||
|
||||
// Handle scoped packages
|
||||
if strings.HasPrefix(pathParts[0], "@") && len(pathParts) >= 3 {
|
||||
packageName = pathParts[0] + "/" + pathParts[1]
|
||||
tarballName = pathParts[len(pathParts)-1]
|
||||
} else {
|
||||
packageName = pathParts[0]
|
||||
tarballName = pathParts[len(pathParts)-1]
|
||||
}
|
||||
|
||||
tarballName = strings.TrimSuffix(tarballName, ".tgz")
|
||||
tarballName = strings.TrimSuffix(tarballName, ".tar.gz")
|
||||
|
||||
// Remove package name prefix to get version
|
||||
prefix := strings.ReplaceAll(packageName, "/", "-") + "-"
|
||||
version := strings.TrimPrefix(tarballName, prefix)
|
||||
|
||||
return packageName, version
|
||||
}
|
||||
|
||||
// getProxyBaseURL constructs the proxy base URL from the request
|
||||
func getProxyBaseURL(r *http.Request) string {
|
||||
scheme := "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
host := r.Host
|
||||
return fmt.Sprintf("%s://%s/npm", scheme, host)
|
||||
}
|
||||
|
||||
// rewriteMetadataURLs recursively rewrites upstream URLs to proxy URLs in metadata
|
||||
func rewriteMetadataURLs(data interface{}, upstream, proxyBaseURL string) {
|
||||
switch v := data.(type) {
|
||||
case map[string]interface{}:
|
||||
for key, value := range v {
|
||||
if key == "tarball" || key == "dist" {
|
||||
// Rewrite tarball URL
|
||||
if strVal, ok := value.(string); ok {
|
||||
v[key] = strings.Replace(strVal, upstream, proxyBaseURL, 1)
|
||||
} else if distMap, ok := value.(map[string]interface{}); ok {
|
||||
// Handle dist object with tarball field
|
||||
rewriteMetadataURLs(distMap, upstream, proxyBaseURL)
|
||||
}
|
||||
} else {
|
||||
// Recursively process nested objects
|
||||
rewriteMetadataURLs(value, upstream, proxyBaseURL)
|
||||
}
|
||||
}
|
||||
case []interface{}:
|
||||
for _, item := range v {
|
||||
rewriteMetadataURLs(item, upstream, proxyBaseURL)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,398 @@
|
||||
package pypi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/auth"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/cache"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/errors"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/network"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Handler implements the PyPI Simple API (PEP 503)
|
||||
type Handler struct {
|
||||
cache *cache.Manager
|
||||
client *network.Client
|
||||
upstream string
|
||||
credExtractor *auth.CredentialExtractor
|
||||
credHasher *auth.CredentialHasher
|
||||
credValidator *auth.PyPIValidator
|
||||
validationCache *auth.ValidationCache
|
||||
}
|
||||
|
||||
// Config holds PyPI proxy configuration
|
||||
type Config struct {
|
||||
Upstream string // Upstream PyPI index (e.g., pypi.org/simple)
|
||||
}
|
||||
|
||||
// New creates a new PyPI proxy handler
|
||||
func New(cacheManager *cache.Manager, client *network.Client, config Config) *Handler {
|
||||
if config.Upstream == "" {
|
||||
config.Upstream = "https://pypi.org/simple"
|
||||
}
|
||||
|
||||
return &Handler{
|
||||
cache: cacheManager,
|
||||
client: client,
|
||||
upstream: config.Upstream,
|
||||
credExtractor: auth.NewCredentialExtractor(),
|
||||
credHasher: auth.NewCredentialHasher(),
|
||||
credValidator: auth.NewPyPIValidator(),
|
||||
validationCache: auth.NewValidationCache(5 * time.Minute),
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP handles PyPI Simple API requests
|
||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
path := strings.TrimPrefix(r.URL.Path, "/pypi")
|
||||
// Also trim /simple prefix since upstream already includes it
|
||||
path = strings.TrimPrefix(path, "/simple")
|
||||
|
||||
log.Debug().Str("path", path).Str("method", r.Method).Msg("PyPI proxy request")
|
||||
|
||||
// PEP 503 Simple API endpoints:
|
||||
// / - index page
|
||||
// /{package}/ - package page with links to files
|
||||
|
||||
if path == "/" || path == "" {
|
||||
// Index page
|
||||
h.handleIndex(ctx, w, r)
|
||||
} else if isPackagePage(path) {
|
||||
// Package page
|
||||
h.handlePackagePage(ctx, w, r, path)
|
||||
} else if isPackageFile(path) {
|
||||
// Package file download (wheel or sdist)
|
||||
h.handlePackageFile(ctx, w, r, path)
|
||||
} else {
|
||||
http.Error(w, "Invalid PyPI request", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
// handleIndex handles the index page request
|
||||
func (h *Handler) handleIndex(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
url := h.upstream + "/"
|
||||
|
||||
entry, err := h.cache.Get(ctx, "pypi", "index", "latest", func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
body, statusCode, err := h.client.Get(ctx, url, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
|
||||
}
|
||||
return body, url, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", url).Msg("Failed to fetch PyPI index")
|
||||
http.Error(w, "Failed to fetch PyPI index", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer entry.Data.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=UTF-8")
|
||||
_, _ = io.Copy(w, entry.Data) // #nosec G104 -- HTTP response write
|
||||
}
|
||||
|
||||
// handlePackagePage handles package page requests
|
||||
func (h *Handler) handlePackagePage(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
|
||||
url := h.upstream + path
|
||||
packageName := extractPackageName(path)
|
||||
|
||||
entry, err := h.cache.Get(ctx, "pypi", packageName, "page", func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
body, statusCode, err := h.client.Get(ctx, url, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
|
||||
}
|
||||
return body, url, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", url).Msg("Failed to fetch package page")
|
||||
http.Error(w, "Failed to fetch package page", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer entry.Data.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
// Read page into memory for URL rewriting
|
||||
var buf bytes.Buffer
|
||||
if _, err := io.Copy(&buf, entry.Data); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to read package page")
|
||||
http.Error(w, "Failed to read package page", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Rewrite package file URLs to point to our proxy
|
||||
proxyBaseURL := getProxyBaseURL(r)
|
||||
modifiedHTML := rewritePackagePageURLs(buf.String(), packageName, proxyBaseURL)
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=UTF-8")
|
||||
_, _ = w.Write([]byte(modifiedHTML)) // #nosec G104 -- Websocket buffer write
|
||||
}
|
||||
|
||||
// handlePackageFile handles package file download requests
|
||||
func (h *Handler) handlePackageFile(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
|
||||
packageName, version := extractPackageFileInfo(path)
|
||||
|
||||
// Extract credentials from request
|
||||
credentials := h.credExtractor.Extract(r)
|
||||
credHash := h.credHasher.Hash(credentials)
|
||||
|
||||
// Check if we have the original URL from the rewritten package page
|
||||
originalURL := r.URL.Query().Get("original_url")
|
||||
|
||||
// If no original URL provided, fall back to constructing from upstream
|
||||
// (this handles direct file requests not from rewritten package pages)
|
||||
if originalURL == "" {
|
||||
originalURL = h.upstream + path
|
||||
} else {
|
||||
// Make the URL absolute if it's relative
|
||||
if !strings.HasPrefix(originalURL, "http://") && !strings.HasPrefix(originalURL, "https://") {
|
||||
originalURL = "https://pypi.org" + originalURL
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("path", path).
|
||||
Str("package", packageName).
|
||||
Str("version", version).
|
||||
Str("url", originalURL).
|
||||
Str("cred_hash", credHash).
|
||||
Bool("has_credentials", credentials != "").
|
||||
Msg("Handling PyPI package file request")
|
||||
|
||||
entry, err := h.cache.Get(ctx, "pypi", packageName, version, func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
// Prepare headers for upstream request
|
||||
headers := make(map[string]string)
|
||||
if credentials != "" {
|
||||
headers["Authorization"] = credentials
|
||||
}
|
||||
|
||||
body, statusCode, err := h.client.Get(ctx, originalURL, headers)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
|
||||
}
|
||||
return body, originalURL, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", originalURL).Msg("Failed to fetch package file")
|
||||
|
||||
// Check if error is a security violation - return 403 Forbidden
|
||||
if ghErr, ok := err.(*errors.Error); ok && ghErr.Code == errors.ErrCodeSecurityViolation {
|
||||
http.Error(w, fmt.Sprintf("Package blocked: %s", ghErr.Message), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// All other errors return 502 Bad Gateway (upstream issues)
|
||||
http.Error(w, "Failed to fetch package file", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer entry.Data.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
// CRITICAL SECURITY CHECK: If package requires auth, validate credentials
|
||||
if entry.Package != nil && entry.Package.RequiresAuth {
|
||||
// Check validation cache first
|
||||
allowed, cached, reason := h.validationCache.Get(credHash, originalURL)
|
||||
if cached {
|
||||
if !allowed {
|
||||
log.Warn().
|
||||
Str("package", packageName).
|
||||
Str("version", version).
|
||||
Str("reason", reason).
|
||||
Msg("Access denied (cached validation)")
|
||||
http.Error(w, "Access denied", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
log.Debug().
|
||||
Str("package", packageName).
|
||||
Str("version", version).
|
||||
Msg("Access granted (cached validation)")
|
||||
} else {
|
||||
// Validate with upstream
|
||||
log.Debug().
|
||||
Str("package", packageName).
|
||||
Str("version", version).
|
||||
Str("provider", entry.Package.AuthProvider).
|
||||
Msg("Validating credentials with upstream")
|
||||
|
||||
allowed, err := h.credValidator.ValidateAccess(ctx, originalURL, credentials)
|
||||
if err != nil {
|
||||
reason = err.Error()
|
||||
}
|
||||
|
||||
// Cache validation result
|
||||
h.validationCache.Set(credHash, originalURL, allowed, reason)
|
||||
|
||||
if !allowed {
|
||||
log.Warn().
|
||||
Str("package", packageName).
|
||||
Str("version", version).
|
||||
Err(err).
|
||||
Msg("Access denied by upstream")
|
||||
http.Error(w, "Access denied", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("package", packageName).
|
||||
Str("version", version).
|
||||
Msg("Access granted by upstream")
|
||||
}
|
||||
}
|
||||
|
||||
// Determine content type based on file extension
|
||||
contentType := "application/octet-stream"
|
||||
if strings.HasSuffix(path, ".whl") {
|
||||
contentType = "application/zip"
|
||||
} else if strings.HasSuffix(path, ".tar.gz") {
|
||||
contentType = "application/x-gzip"
|
||||
} else if strings.HasSuffix(path, ".metadata") {
|
||||
contentType = "text/plain; charset=UTF-8"
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
_, _ = io.Copy(w, entry.Data) // #nosec G104 -- HTTP response write
|
||||
}
|
||||
|
||||
// isPackagePage checks if the request is for a package page
|
||||
func isPackagePage(path string) bool {
|
||||
// Package pages end with /
|
||||
return strings.HasSuffix(path, "/")
|
||||
}
|
||||
|
||||
// isPackageFile checks if the request is for a package file
|
||||
func isPackageFile(path string) bool {
|
||||
// Package files (not including .metadata files which need special handling)
|
||||
return strings.HasSuffix(path, ".whl") ||
|
||||
strings.HasSuffix(path, ".tar.gz") ||
|
||||
strings.HasSuffix(path, ".zip") ||
|
||||
strings.HasSuffix(path, ".egg")
|
||||
}
|
||||
|
||||
// extractPackageName extracts package name from path
|
||||
func extractPackageName(path string) string {
|
||||
// Remove leading and trailing slashes
|
||||
path = strings.Trim(path, "/")
|
||||
|
||||
// Remove /simple/ prefix if present
|
||||
path = strings.TrimPrefix(path, "simple/")
|
||||
|
||||
// For package pages: /package-name/
|
||||
// For files: /package-name/package-name-version.whl
|
||||
parts := strings.Split(path, "/")
|
||||
if len(parts) > 0 {
|
||||
return parts[0]
|
||||
}
|
||||
|
||||
return path
|
||||
}
|
||||
|
||||
// extractPackageFileInfo extracts package name and version from file path
|
||||
func extractPackageFileInfo(path string) (string, string) {
|
||||
// Format: /package-name/package-name-version.whl
|
||||
// or: /package-name/package-name-version.tar.gz
|
||||
|
||||
packageName := extractPackageName(path)
|
||||
|
||||
// Extract filename
|
||||
parts := strings.Split(path, "/")
|
||||
if len(parts) < 2 {
|
||||
return packageName, ""
|
||||
}
|
||||
|
||||
filename := parts[len(parts)-1]
|
||||
|
||||
// Remove extension
|
||||
filename = strings.TrimSuffix(filename, ".whl")
|
||||
filename = strings.TrimSuffix(filename, ".tar.gz")
|
||||
filename = strings.TrimSuffix(filename, ".zip")
|
||||
filename = strings.TrimSuffix(filename, ".egg")
|
||||
|
||||
// Extract version
|
||||
// Filename format: package-name-version or package_name-version
|
||||
// Version typically starts after last dash before build tags
|
||||
versionParts := strings.Split(filename, "-")
|
||||
if len(versionParts) >= 2 {
|
||||
// Simple heuristic: version is the part that starts with a digit
|
||||
for i := 1; i < len(versionParts); i++ {
|
||||
if len(versionParts[i]) > 0 && versionParts[i][0] >= '0' && versionParts[i][0] <= '9' {
|
||||
return packageName, versionParts[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return packageName, filename
|
||||
}
|
||||
|
||||
// getProxyBaseURL constructs the proxy base URL from the request
|
||||
func getProxyBaseURL(r *http.Request) string {
|
||||
scheme := "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
host := r.Host
|
||||
return fmt.Sprintf("%s://%s/pypi", scheme, host)
|
||||
}
|
||||
|
||||
// rewritePackagePageURLs rewrites package file URLs in HTML to point to proxy
|
||||
func rewritePackagePageURLs(html, packageName, proxyBaseURL string) string {
|
||||
// PyPI Simple API uses href attributes in anchor tags
|
||||
// We need to rewrite URLs pointing to files.pythonhosted.org or pypi.org
|
||||
// We preserve the original URL as a query parameter so we can fetch from the correct CDN
|
||||
|
||||
// Regex pattern to match href URLs pointing to package files
|
||||
// Matches: href="https://files.pythonhosted.org/packages/.../filename.whl"
|
||||
// Also matches: href="../../packages/.../filename.whl"
|
||||
pattern := regexp.MustCompile(`href="([^"]*?(\.whl|\.tar\.gz|\.zip|\.egg)[^"]*?)"`)
|
||||
|
||||
result := pattern.ReplaceAllStringFunc(html, func(match string) string {
|
||||
// Extract the full URL and filename
|
||||
urlPattern := regexp.MustCompile(`href="([^"]+)"`)
|
||||
urlMatch := urlPattern.FindStringSubmatch(match)
|
||||
if len(urlMatch) < 2 {
|
||||
return match
|
||||
}
|
||||
|
||||
originalURL := urlMatch[1]
|
||||
|
||||
// Extract just the filename
|
||||
filenamePattern := regexp.MustCompile(`([^/]+\.(whl|tar\.gz|zip|egg))`)
|
||||
filenameMatch := filenamePattern.FindString(originalURL)
|
||||
|
||||
if filenameMatch != "" {
|
||||
// Rewrite to proxy URL format: /pypi/package-name/filename?original_url=...
|
||||
// This preserves the original CDN URL so we can fetch from the correct location
|
||||
baseURL := strings.TrimSuffix(proxyBaseURL, "/simple")
|
||||
|
||||
// URL encode the original URL
|
||||
encodedURL := strings.ReplaceAll(originalURL, "&", "%26")
|
||||
encodedURL = strings.ReplaceAll(encodedURL, "=", "%3D")
|
||||
|
||||
newURL := fmt.Sprintf(`href="%s/%s/%s?original_url=%s"`, baseURL, packageName, filenameMatch, encodedURL)
|
||||
return newURL
|
||||
}
|
||||
|
||||
return match
|
||||
})
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,283 @@
|
||||
package ghsa
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/config"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// ScannerName is the name of this scanner
|
||||
const ScannerName = "github-advisory-database"
|
||||
|
||||
// Scanner implements the GitHub Advisory Database vulnerability scanner
|
||||
type Scanner struct {
|
||||
config config.GHSAConfig
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// New creates a new GitHub Advisory Database scanner
|
||||
func New(cfg config.GHSAConfig) *Scanner {
|
||||
return &Scanner{
|
||||
config: cfg,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the scanner name
|
||||
func (s *Scanner) Name() string {
|
||||
return ScannerName
|
||||
}
|
||||
|
||||
// Scan scans a package using GitHub Advisory Database API
|
||||
func (s *Scanner) Scan(ctx context.Context, registry, packageName, version string, filePath string) (*metadata.ScanResult, error) {
|
||||
log.Info().
|
||||
Str("scanner", ScannerName).
|
||||
Str("package", packageName).
|
||||
Str("version", version).
|
||||
Str("registry", registry).
|
||||
Msg("Starting GitHub Advisory Database scan")
|
||||
|
||||
// Map registry to GitHub ecosystem
|
||||
ecosystem := s.mapRegistryToEcosystem(registry)
|
||||
if ecosystem == "" {
|
||||
return &metadata.ScanResult{
|
||||
ID: uuid.New().String(),
|
||||
Registry: registry,
|
||||
PackageName: packageName,
|
||||
PackageVersion: version,
|
||||
Scanner: ScannerName,
|
||||
ScannedAt: time.Now(),
|
||||
Status: metadata.ScanStatusClean,
|
||||
VulnerabilityCount: 0,
|
||||
Vulnerabilities: []metadata.Vulnerability{},
|
||||
Details: map[string]interface{}{
|
||||
"skipped": fmt.Sprintf("GitHub Advisory Database does not support registry: %s", registry),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Query GitHub Advisory Database
|
||||
advisories, err := s.queryAdvisories(ctx, ecosystem, packageName)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to query GitHub Advisory Database")
|
||||
return s.emptyResult(registry, packageName, version), nil
|
||||
}
|
||||
|
||||
// Filter advisories that affect this version
|
||||
affectedAdvisories := s.filterAffectedAdvisories(advisories, version)
|
||||
|
||||
// Convert to our format
|
||||
result := s.convertResult(affectedAdvisories, registry, packageName, version)
|
||||
|
||||
log.Info().
|
||||
Str("scanner", ScannerName).
|
||||
Str("package", packageName).
|
||||
Int("vulnerabilities", result.VulnerabilityCount).
|
||||
Msg("GitHub Advisory Database scan completed")
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Health checks if GitHub API is accessible
|
||||
func (s *Scanner) Health(ctx context.Context) error {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/advisories", nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/vnd.github+json")
|
||||
if s.config.Token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+s.config.Token)
|
||||
}
|
||||
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("github advisory database not accessible: %w", err)
|
||||
}
|
||||
defer resp.Body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("github api returned status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// mapRegistryToEcosystem maps our registry names to GitHub ecosystem names
|
||||
func (s *Scanner) mapRegistryToEcosystem(registry string) string {
|
||||
mapping := map[string]string{
|
||||
"npm": "npm",
|
||||
"pypi": "pip",
|
||||
"go": "go",
|
||||
"maven": "maven",
|
||||
"nuget": "nuget",
|
||||
"cargo": "cargo",
|
||||
"pub": "pub",
|
||||
}
|
||||
return mapping[strings.ToLower(registry)]
|
||||
}
|
||||
|
||||
// queryAdvisories queries GitHub Advisory Database for a package
|
||||
func (s *Scanner) queryAdvisories(ctx context.Context, ecosystem, packageName string) ([]GHSAAdvisory, error) {
|
||||
url := fmt.Sprintf("https://api.github.com/advisories?ecosystem=%s&affects=%s&per_page=100", ecosystem, packageName)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/vnd.github+json")
|
||||
if s.config.Token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+s.config.Token)
|
||||
}
|
||||
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query advisories: %w", err)
|
||||
}
|
||||
defer resp.Body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("github api returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var advisories []GHSAAdvisory
|
||||
if err := json.NewDecoder(resp.Body).Decode(&advisories); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
return advisories, nil
|
||||
}
|
||||
|
||||
// filterAffectedAdvisories filters advisories that affect the given version
|
||||
func (s *Scanner) filterAffectedAdvisories(advisories []GHSAAdvisory, version string) []GHSAAdvisory {
|
||||
// Check if this version is affected
|
||||
// GitHub API already filters by package, but we need to check version ranges
|
||||
// For now, we'll include all advisories that match the package
|
||||
// A more sophisticated implementation would parse version ranges
|
||||
affected := append([]GHSAAdvisory(nil), advisories...)
|
||||
|
||||
return affected
|
||||
}
|
||||
|
||||
// emptyResult returns an empty scan result
|
||||
func (s *Scanner) emptyResult(registry, packageName, version string) *metadata.ScanResult {
|
||||
return &metadata.ScanResult{
|
||||
ID: uuid.New().String(),
|
||||
Registry: registry,
|
||||
PackageName: packageName,
|
||||
PackageVersion: version,
|
||||
Scanner: ScannerName,
|
||||
ScannedAt: time.Now(),
|
||||
Status: metadata.ScanStatusClean,
|
||||
VulnerabilityCount: 0,
|
||||
Vulnerabilities: []metadata.Vulnerability{},
|
||||
Details: map[string]interface{}{},
|
||||
}
|
||||
}
|
||||
|
||||
// convertResult converts GitHub Advisory Database results to our ScanResult format
|
||||
func (s *Scanner) convertResult(advisories []GHSAAdvisory, registry, packageName, version string) *metadata.ScanResult {
|
||||
vulnerabilities := make([]metadata.Vulnerability, 0)
|
||||
severityCounts := make(map[string]int)
|
||||
|
||||
for _, advisory := range advisories {
|
||||
// Normalize severity
|
||||
normalizedSeverity := metadata.NormalizeSeverity(advisory.Severity)
|
||||
severityCounts[normalizedSeverity]++
|
||||
|
||||
// Extract references
|
||||
refs := make([]string, 0)
|
||||
if advisory.HTMLURL != "" {
|
||||
refs = append(refs, advisory.HTMLURL)
|
||||
}
|
||||
for _, ref := range advisory.References {
|
||||
if ref.URL != "" {
|
||||
refs = append(refs, ref.URL)
|
||||
}
|
||||
}
|
||||
|
||||
// Get fixed versions
|
||||
fixedIn := ""
|
||||
for _, vuln := range advisory.Vulnerabilities {
|
||||
if vuln.FirstPatchedVersion != nil && vuln.FirstPatchedVersion.Identifier != "" {
|
||||
fixedIn = vuln.FirstPatchedVersion.Identifier
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
vulnerabilities = append(vulnerabilities, metadata.Vulnerability{
|
||||
ID: advisory.GHSAID,
|
||||
Severity: normalizedSeverity,
|
||||
Title: advisory.Summary,
|
||||
Description: advisory.Description,
|
||||
References: refs,
|
||||
FixedIn: fixedIn,
|
||||
})
|
||||
}
|
||||
|
||||
status := metadata.ScanStatusClean
|
||||
if len(vulnerabilities) > 0 {
|
||||
status = metadata.ScanStatusVulnerable
|
||||
}
|
||||
|
||||
return &metadata.ScanResult{
|
||||
ID: uuid.New().String(),
|
||||
Registry: registry,
|
||||
PackageName: packageName,
|
||||
PackageVersion: version,
|
||||
Scanner: ScannerName,
|
||||
ScannedAt: time.Now(),
|
||||
Status: status,
|
||||
VulnerabilityCount: len(vulnerabilities),
|
||||
Vulnerabilities: vulnerabilities,
|
||||
Details: map[string]interface{}{
|
||||
"severity_counts": severityCounts,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GHSAAdvisory represents a GitHub Security Advisory
|
||||
type GHSAAdvisory struct {
|
||||
GHSAID string `json:"ghsa_id"`
|
||||
CVEID string `json:"cve_id"`
|
||||
Summary string `json:"summary"`
|
||||
Description string `json:"description"`
|
||||
Severity string `json:"severity"`
|
||||
HTMLURL string `json:"html_url"`
|
||||
References []GHSAReference `json:"references"`
|
||||
Vulnerabilities []GHSAVulnerability `json:"vulnerabilities"`
|
||||
PublishedAt string `json:"published_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
type GHSAReference struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
type GHSAVulnerability struct {
|
||||
Package GHSAPackage `json:"package"`
|
||||
VulnerableVersions string `json:"vulnerable_version_range"`
|
||||
FirstPatchedVersion *GHSAPatchVersion `json:"first_patched_version"`
|
||||
}
|
||||
|
||||
type GHSAPackage struct {
|
||||
Ecosystem string `json:"ecosystem"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type GHSAPatchVersion struct {
|
||||
Identifier string `json:"identifier"`
|
||||
}
|
||||
@@ -0,0 +1,194 @@
|
||||
package govulncheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/config"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// ScannerName is the name of this scanner
|
||||
const ScannerName = "govulncheck"
|
||||
|
||||
// Scanner implements the govulncheck vulnerability scanner for Go modules
|
||||
type Scanner struct {
|
||||
config config.GovulncheckConfig
|
||||
}
|
||||
|
||||
// New creates a new govulncheck scanner
|
||||
func New(cfg config.GovulncheckConfig) *Scanner {
|
||||
return &Scanner{
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the scanner name
|
||||
func (s *Scanner) Name() string {
|
||||
return ScannerName
|
||||
}
|
||||
|
||||
// Scan scans a Go module using govulncheck
|
||||
func (s *Scanner) Scan(ctx context.Context, registry, packageName, version string, filePath string) (*metadata.ScanResult, error) {
|
||||
// Only scan Go packages
|
||||
if registry != "go" {
|
||||
return &metadata.ScanResult{
|
||||
ID: uuid.New().String(),
|
||||
Registry: registry,
|
||||
PackageName: packageName,
|
||||
PackageVersion: version,
|
||||
Scanner: ScannerName,
|
||||
ScannedAt: time.Now(),
|
||||
Status: metadata.ScanStatusClean,
|
||||
VulnerabilityCount: 0,
|
||||
Vulnerabilities: []metadata.Vulnerability{},
|
||||
Details: map[string]interface{}{
|
||||
"skipped": "govulncheck only supports Go modules",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("scanner", ScannerName).
|
||||
Str("package", packageName).
|
||||
Str("version", version).
|
||||
Msg("Starting govulncheck scan")
|
||||
|
||||
// Create a temporary directory for extraction
|
||||
tmpDir, err := os.MkdirTemp("", "govulncheck-*")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create temp dir: %w", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Extract the .zip file
|
||||
if err := s.extractZip(filePath, tmpDir); err != nil {
|
||||
return nil, fmt.Errorf("failed to extract zip: %w", err)
|
||||
}
|
||||
|
||||
// Run govulncheck
|
||||
cmd := exec.CommandContext(ctx, "govulncheck", "-json", "-mode=binary", tmpDir) // #nosec G204 -- govulncheck command with temp directory
|
||||
output, _ := cmd.CombinedOutput()
|
||||
|
||||
// govulncheck returns non-zero when vulnerabilities are found
|
||||
// Parse output regardless of error
|
||||
var vulns []GovulncheckVuln
|
||||
if len(output) > 0 {
|
||||
// Parse line-delimited JSON
|
||||
lines := strings.Split(string(output), "\n")
|
||||
for _, line := range lines {
|
||||
if strings.TrimSpace(line) == "" {
|
||||
continue
|
||||
}
|
||||
var entry GovulncheckEntry
|
||||
if err := json.Unmarshal([]byte(line), &entry); err != nil {
|
||||
log.Warn().Err(err).Str("line", line).Msg("Failed to parse govulncheck line")
|
||||
continue
|
||||
}
|
||||
if entry.Finding != nil && entry.Finding.OSV != "" {
|
||||
vulns = append(vulns, GovulncheckVuln{
|
||||
OSV: entry.Finding.OSV,
|
||||
FixedVersion: entry.Finding.FixedVersion,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to our format
|
||||
result := s.convertResult(vulns, registry, packageName, version)
|
||||
|
||||
log.Info().
|
||||
Str("scanner", ScannerName).
|
||||
Str("package", packageName).
|
||||
Int("vulnerabilities", result.VulnerabilityCount).
|
||||
Msg("govulncheck scan completed")
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Health checks if govulncheck is available
|
||||
func (s *Scanner) Health(ctx context.Context) error {
|
||||
cmd := exec.CommandContext(ctx, "govulncheck", "-version")
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("govulncheck not available: %w (install with: go install golang.org/x/vuln/cmd/govulncheck@latest)", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractZip extracts a zip file to destination
|
||||
func (s *Scanner) extractZip(zipPath, destDir string) error {
|
||||
cmd := exec.Command("unzip", "-q", zipPath, "-d", destDir)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// convertResult converts govulncheck findings to our ScanResult format
|
||||
func (s *Scanner) convertResult(vulns []GovulncheckVuln, registry, packageName, version string) *metadata.ScanResult {
|
||||
vulnerabilities := make([]metadata.Vulnerability, 0)
|
||||
severityCounts := make(map[string]int)
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, vuln := range vulns {
|
||||
// Deduplicate by OSV ID
|
||||
if seen[vuln.OSV] {
|
||||
continue
|
||||
}
|
||||
seen[vuln.OSV] = true
|
||||
|
||||
// govulncheck doesn't provide severity in output
|
||||
// Default to HIGH for found vulnerabilities
|
||||
severity := metadata.NormalizeSeverity("HIGH")
|
||||
severityCounts[severity]++
|
||||
|
||||
vulnerabilities = append(vulnerabilities, metadata.Vulnerability{
|
||||
ID: vuln.OSV,
|
||||
Severity: severity,
|
||||
Title: vuln.OSV,
|
||||
Description: fmt.Sprintf("Vulnerability %s found by govulncheck", vuln.OSV),
|
||||
References: []string{fmt.Sprintf("https://pkg.go.dev/vuln/%s", vuln.OSV)},
|
||||
FixedIn: vuln.FixedVersion,
|
||||
})
|
||||
}
|
||||
|
||||
status := metadata.ScanStatusClean
|
||||
if len(vulnerabilities) > 0 {
|
||||
status = metadata.ScanStatusVulnerable
|
||||
}
|
||||
|
||||
return &metadata.ScanResult{
|
||||
ID: uuid.New().String(),
|
||||
Registry: registry,
|
||||
PackageName: packageName,
|
||||
PackageVersion: version,
|
||||
Scanner: ScannerName,
|
||||
ScannedAt: time.Now(),
|
||||
Status: status,
|
||||
VulnerabilityCount: len(vulnerabilities),
|
||||
Vulnerabilities: vulnerabilities,
|
||||
Details: map[string]interface{}{
|
||||
"severity_counts": severityCounts,
|
||||
"note": "govulncheck provides reachability analysis for Go modules",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GovulncheckEntry represents a single line of govulncheck JSON output
|
||||
type GovulncheckEntry struct {
|
||||
Finding *GovulncheckFinding `json:"finding,omitempty"`
|
||||
}
|
||||
|
||||
type GovulncheckFinding struct {
|
||||
OSV string `json:"osv"`
|
||||
FixedVersion string `json:"fixed_version,omitempty"`
|
||||
}
|
||||
|
||||
type GovulncheckVuln struct {
|
||||
OSV string
|
||||
FixedVersion string
|
||||
}
|
||||
@@ -0,0 +1,193 @@
|
||||
package grype
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/config"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// ScannerName is the name of this scanner
|
||||
const ScannerName = "grype"
|
||||
|
||||
// Scanner implements the Grype vulnerability scanner
|
||||
type Scanner struct {
|
||||
config config.GrypeConfig
|
||||
}
|
||||
|
||||
// New creates a new Grype scanner
|
||||
func New(cfg config.GrypeConfig) *Scanner {
|
||||
return &Scanner{
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the scanner name
|
||||
func (s *Scanner) Name() string {
|
||||
return ScannerName
|
||||
}
|
||||
|
||||
// Scan scans a package using Grype
|
||||
func (s *Scanner) Scan(ctx context.Context, registry, packageName, version string, filePath string) (*metadata.ScanResult, error) {
|
||||
log.Info().
|
||||
Str("scanner", ScannerName).
|
||||
Str("package", packageName).
|
||||
Str("version", version).
|
||||
Str("file", filePath).
|
||||
Msg("Starting Grype scan")
|
||||
|
||||
// Run grype scan
|
||||
cmd := exec.CommandContext(ctx, "grype", filePath, "-o", "json", "-q")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
// Grype returns non-zero exit code when vulnerabilities are found
|
||||
// Only treat it as error if we got no output
|
||||
if len(output) == 0 {
|
||||
return nil, fmt.Errorf("grype scan failed: %w (output: %s)", err, string(output))
|
||||
}
|
||||
}
|
||||
|
||||
// Parse Grype JSON output
|
||||
var grypeResult GrypeResult
|
||||
if err := json.Unmarshal(output, &grypeResult); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse grype output: %w", err)
|
||||
}
|
||||
|
||||
// Convert to our format
|
||||
result := s.convertGrypeResult(&grypeResult, registry, packageName, version)
|
||||
|
||||
log.Info().
|
||||
Str("scanner", ScannerName).
|
||||
Str("package", packageName).
|
||||
Int("vulnerabilities", result.VulnerabilityCount).
|
||||
Msg("Grype scan completed")
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Health checks if Grype is available
|
||||
func (s *Scanner) Health(ctx context.Context) error {
|
||||
cmd := exec.CommandContext(ctx, "grype", "version")
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("grype not available: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateDatabase updates Grype's vulnerability database
|
||||
func (s *Scanner) UpdateDatabase(ctx context.Context) error {
|
||||
log.Info().Str("scanner", ScannerName).Msg("Updating Grype database")
|
||||
|
||||
cmd := exec.CommandContext(ctx, "grype", "db", "update")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update grype database: %w (output: %s)", err, string(output))
|
||||
}
|
||||
|
||||
log.Info().Str("scanner", ScannerName).Msg("Grype database updated successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// convertGrypeResult converts Grype output to our ScanResult format
|
||||
func (s *Scanner) convertGrypeResult(grypeResult *GrypeResult, registry, packageName, version string) *metadata.ScanResult {
|
||||
vulnerabilities := make([]metadata.Vulnerability, 0)
|
||||
severityCounts := make(map[string]int)
|
||||
|
||||
// Process each vulnerability match
|
||||
for _, match := range grypeResult.Matches {
|
||||
// Normalize severity
|
||||
normalizedSeverity := metadata.NormalizeSeverity(match.Vulnerability.Severity)
|
||||
|
||||
// Count by severity
|
||||
severityCounts[normalizedSeverity]++
|
||||
|
||||
// Extract fixed version
|
||||
fixedIn := ""
|
||||
if match.Vulnerability.Fix.State == "fixed" {
|
||||
for _, version := range match.Vulnerability.Fix.Versions {
|
||||
if fixedIn == "" {
|
||||
fixedIn = version
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add to vulnerabilities list
|
||||
vulnerabilities = append(vulnerabilities, metadata.Vulnerability{
|
||||
ID: match.Vulnerability.ID,
|
||||
Severity: normalizedSeverity,
|
||||
Title: match.Vulnerability.ID, // Grype doesn't have separate title
|
||||
Description: match.Vulnerability.Description,
|
||||
References: match.Vulnerability.URLs,
|
||||
FixedIn: fixedIn,
|
||||
})
|
||||
}
|
||||
|
||||
// Determine overall status
|
||||
status := metadata.ScanStatusClean
|
||||
if len(vulnerabilities) > 0 {
|
||||
status = metadata.ScanStatusVulnerable
|
||||
}
|
||||
|
||||
return &metadata.ScanResult{
|
||||
ID: uuid.New().String(),
|
||||
Registry: registry,
|
||||
PackageName: packageName,
|
||||
PackageVersion: version,
|
||||
Scanner: ScannerName,
|
||||
ScannedAt: time.Now(),
|
||||
Status: status,
|
||||
VulnerabilityCount: len(vulnerabilities),
|
||||
Vulnerabilities: vulnerabilities,
|
||||
Details: map[string]interface{}{
|
||||
"severity_counts": severityCounts,
|
||||
"grype_version": grypeResult.Descriptor.Version,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GrypeResult represents Grype JSON output structure
|
||||
type GrypeResult struct {
|
||||
Matches []GrypeMatch `json:"matches"`
|
||||
Descriptor GrypeDescriptor `json:"descriptor"`
|
||||
Source GrypeSource `json:"source"`
|
||||
}
|
||||
|
||||
type GrypeDescriptor struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
type GrypeSource struct {
|
||||
Type string `json:"type"`
|
||||
Target map[string]interface{} `json:"target"`
|
||||
}
|
||||
|
||||
type GrypeMatch struct {
|
||||
Vulnerability GrypeVulnerability `json:"vulnerability"`
|
||||
Artifact GrypeArtifact `json:"artifact"`
|
||||
}
|
||||
|
||||
type GrypeVulnerability struct {
|
||||
ID string `json:"id"`
|
||||
Severity string `json:"severity"`
|
||||
Description string `json:"description"`
|
||||
URLs []string `json:"urls"`
|
||||
Fix GrypeFix `json:"fix"`
|
||||
}
|
||||
|
||||
type GrypeFix struct {
|
||||
State string `json:"state"`
|
||||
Versions []string `json:"versions"`
|
||||
}
|
||||
|
||||
type GrypeArtifact struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
@@ -0,0 +1,234 @@
|
||||
package npmaudit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/config"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// ScannerName is the name of this scanner
|
||||
const ScannerName = "npm-audit"
|
||||
|
||||
// Scanner implements the npm audit vulnerability scanner
|
||||
type Scanner struct {
|
||||
config config.NpmAuditConfig
|
||||
}
|
||||
|
||||
// New creates a new npm audit scanner
|
||||
func New(cfg config.NpmAuditConfig) *Scanner {
|
||||
return &Scanner{
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the scanner name
|
||||
func (s *Scanner) Name() string {
|
||||
return ScannerName
|
||||
}
|
||||
|
||||
// Scan scans an npm package using npm audit
|
||||
func (s *Scanner) Scan(ctx context.Context, registry, packageName, version string, filePath string) (*metadata.ScanResult, error) {
|
||||
// Only scan npm packages
|
||||
if registry != "npm" {
|
||||
return &metadata.ScanResult{
|
||||
ID: uuid.New().String(),
|
||||
Registry: registry,
|
||||
PackageName: packageName,
|
||||
PackageVersion: version,
|
||||
Scanner: ScannerName,
|
||||
ScannedAt: time.Now(),
|
||||
Status: metadata.ScanStatusClean,
|
||||
VulnerabilityCount: 0,
|
||||
Vulnerabilities: []metadata.Vulnerability{},
|
||||
Details: map[string]interface{}{
|
||||
"skipped": "npm-audit only supports npm packages",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("scanner", ScannerName).
|
||||
Str("package", packageName).
|
||||
Str("version", version).
|
||||
Msg("Starting npm audit scan")
|
||||
|
||||
// Create a temporary directory
|
||||
tmpDir, err := os.MkdirTemp("", "npm-audit-*")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create temp dir: %w", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Extract the .tgz file
|
||||
if err := s.extractTgz(filePath, tmpDir); err != nil {
|
||||
return nil, fmt.Errorf("failed to extract tgz: %w", err)
|
||||
}
|
||||
|
||||
// Find the package directory (usually "package/")
|
||||
packageDir := filepath.Join(tmpDir, "package")
|
||||
if _, err := os.Stat(packageDir); os.IsNotExist(err) {
|
||||
// Try the tmpDir itself
|
||||
packageDir = tmpDir
|
||||
}
|
||||
|
||||
// Run npm audit
|
||||
cmd := exec.CommandContext(ctx, "npm", "audit", "--json", "--package-lock-only")
|
||||
cmd.Dir = packageDir
|
||||
output, _ := cmd.CombinedOutput() // npm audit returns non-zero when vulns found
|
||||
|
||||
// Parse npm audit output
|
||||
var auditResult NpmAuditResult
|
||||
if len(output) > 0 {
|
||||
if err := json.Unmarshal(output, &auditResult); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to parse npm audit output")
|
||||
// Return clean result on parse error
|
||||
return s.emptyResult(registry, packageName, version), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to our format
|
||||
result := s.convertResult(&auditResult, registry, packageName, version)
|
||||
|
||||
log.Info().
|
||||
Str("scanner", ScannerName).
|
||||
Str("package", packageName).
|
||||
Int("vulnerabilities", result.VulnerabilityCount).
|
||||
Msg("npm audit scan completed")
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Health checks if npm is available
|
||||
func (s *Scanner) Health(ctx context.Context) error {
|
||||
cmd := exec.CommandContext(ctx, "npm", "--version")
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("npm not available: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractTgz extracts a .tgz file
|
||||
func (s *Scanner) extractTgz(tgzPath, destDir string) error {
|
||||
cmd := exec.Command("tar", "-xzf", tgzPath, "-C", destDir)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// emptyResult returns an empty scan result
|
||||
func (s *Scanner) emptyResult(registry, packageName, version string) *metadata.ScanResult {
|
||||
return &metadata.ScanResult{
|
||||
ID: uuid.New().String(),
|
||||
Registry: registry,
|
||||
PackageName: packageName,
|
||||
PackageVersion: version,
|
||||
Scanner: ScannerName,
|
||||
ScannedAt: time.Now(),
|
||||
Status: metadata.ScanStatusClean,
|
||||
VulnerabilityCount: 0,
|
||||
Vulnerabilities: []metadata.Vulnerability{},
|
||||
Details: map[string]interface{}{},
|
||||
}
|
||||
}
|
||||
|
||||
// convertResult converts npm audit output to our ScanResult format
|
||||
func (s *Scanner) convertResult(auditResult *NpmAuditResult, registry, packageName, version string) *metadata.ScanResult {
|
||||
vulnerabilities := make([]metadata.Vulnerability, 0)
|
||||
severityCounts := make(map[string]int)
|
||||
|
||||
// Process vulnerabilities from the audit result
|
||||
for _, vuln := range auditResult.Vulnerabilities {
|
||||
// Normalize severity
|
||||
normalizedSeverity := metadata.NormalizeSeverity(vuln.Severity)
|
||||
severityCounts[normalizedSeverity]++
|
||||
|
||||
// Get references
|
||||
refs := make([]string, 0)
|
||||
if vuln.URL != "" {
|
||||
refs = append(refs, vuln.URL)
|
||||
}
|
||||
for _, ref := range vuln.References {
|
||||
if ref.URL != "" {
|
||||
refs = append(refs, ref.URL)
|
||||
}
|
||||
}
|
||||
|
||||
// Get fixed version
|
||||
fixedIn := ""
|
||||
if vuln.FixAvailable != nil {
|
||||
fixedIn = fmt.Sprintf("%v", vuln.FixAvailable)
|
||||
}
|
||||
|
||||
vulnerabilities = append(vulnerabilities, metadata.Vulnerability{
|
||||
ID: vuln.Via,
|
||||
Severity: normalizedSeverity,
|
||||
Title: vuln.Name,
|
||||
Description: vuln.Name,
|
||||
References: refs,
|
||||
FixedIn: fixedIn,
|
||||
})
|
||||
}
|
||||
|
||||
status := metadata.ScanStatusClean
|
||||
if len(vulnerabilities) > 0 {
|
||||
status = metadata.ScanStatusVulnerable
|
||||
}
|
||||
|
||||
return &metadata.ScanResult{
|
||||
ID: uuid.New().String(),
|
||||
Registry: registry,
|
||||
PackageName: packageName,
|
||||
PackageVersion: version,
|
||||
Scanner: ScannerName,
|
||||
ScannedAt: time.Now(),
|
||||
Status: status,
|
||||
VulnerabilityCount: len(vulnerabilities),
|
||||
Vulnerabilities: vulnerabilities,
|
||||
Details: map[string]interface{}{
|
||||
"severity_counts": severityCounts,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NpmAuditResult represents npm audit JSON output
|
||||
type NpmAuditResult struct {
|
||||
AuditReportVersion int `json:"auditReportVersion"`
|
||||
Vulnerabilities map[string]NpmVulnerability `json:"vulnerabilities"`
|
||||
Metadata NpmAuditMetadata `json:"metadata"`
|
||||
}
|
||||
|
||||
type NpmVulnerability struct {
|
||||
Name string `json:"name"`
|
||||
Severity string `json:"severity"`
|
||||
Via string `json:"via"`
|
||||
Effects []string `json:"effects"`
|
||||
Range string `json:"range"`
|
||||
FixAvailable interface{} `json:"fixAvailable"`
|
||||
URL string `json:"url"`
|
||||
References []NpmReference `json:"references"`
|
||||
}
|
||||
|
||||
type NpmReference struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
type NpmAuditMetadata struct {
|
||||
Vulnerabilities NpmVulnCounts `json:"vulnerabilities"`
|
||||
Dependencies int `json:"dependencies"`
|
||||
}
|
||||
|
||||
type NpmVulnCounts struct {
|
||||
Info int `json:"info"`
|
||||
Low int `json:"low"`
|
||||
Moderate int `json:"moderate"`
|
||||
High int `json:"high"`
|
||||
Critical int `json:"critical"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
@@ -0,0 +1,329 @@
|
||||
package osv
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/config"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
const (
|
||||
// ScannerName is the name of this scanner
|
||||
ScannerName = "osv"
|
||||
|
||||
defaultOSVAPIURL = "https://api.osv.dev/v1/query"
|
||||
)
|
||||
|
||||
// Scanner implements the Scanner interface using OSV.dev API
|
||||
type Scanner struct {
|
||||
config config.OSVConfig
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// OSVRequest represents the request structure for OSV API
|
||||
type OSVRequest struct {
|
||||
Package PackageInfo `json:"package"`
|
||||
Version string `json:"version,omitempty"`
|
||||
}
|
||||
|
||||
// PackageInfo contains package ecosystem and name
|
||||
type PackageInfo struct {
|
||||
Name string `json:"name"`
|
||||
Ecosystem string `json:"ecosystem"` // npm, PyPI, Go, etc.
|
||||
}
|
||||
|
||||
// OSVResponse represents the response from OSV API
|
||||
type OSVResponse struct {
|
||||
Vulns []OSVVulnerability `json:"vulns"`
|
||||
}
|
||||
|
||||
// OSVVulnerability represents a vulnerability in OSV format
|
||||
type OSVVulnerability struct {
|
||||
ID string `json:"id"`
|
||||
Summary string `json:"summary"`
|
||||
Details string `json:"details"`
|
||||
Severity []OSVSeverity `json:"severity,omitempty"`
|
||||
References []OSVReference `json:"references,omitempty"`
|
||||
Affected []OSVAffected `json:"affected"`
|
||||
DatabaseSpecific map[string]interface{} `json:"database_specific,omitempty"`
|
||||
}
|
||||
|
||||
// OSVSeverity represents severity information
|
||||
type OSVSeverity struct {
|
||||
Type string `json:"type"` // CVSS_V3, etc.
|
||||
Score string `json:"score"` // Severity score
|
||||
}
|
||||
|
||||
// OSVReference represents a reference link
|
||||
type OSVReference struct {
|
||||
Type string `json:"type"` // WEB, ADVISORY, etc.
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// OSVAffected represents affected package versions
|
||||
type OSVAffected struct {
|
||||
Package PackageInfo `json:"package"`
|
||||
Ranges []OSVRange `json:"ranges,omitempty"`
|
||||
Versions []string `json:"versions,omitempty"`
|
||||
DatabaseSpecific map[string]interface{} `json:"database_specific,omitempty"`
|
||||
EcosystemSpecific map[string]interface{} `json:"ecosystem_specific,omitempty"`
|
||||
}
|
||||
|
||||
// OSVRange represents version ranges
|
||||
type OSVRange struct {
|
||||
Type string `json:"type"` // SEMVER, GIT, etc.
|
||||
Events []OSVEvent `json:"events"`
|
||||
}
|
||||
|
||||
// OSVEvent represents version range events
|
||||
type OSVEvent struct {
|
||||
Introduced string `json:"introduced,omitempty"`
|
||||
Fixed string `json:"fixed,omitempty"`
|
||||
LastAffected string `json:"last_affected,omitempty"`
|
||||
}
|
||||
|
||||
// New creates a new OSV scanner
|
||||
func New(cfg config.OSVConfig) *Scanner {
|
||||
apiURL := cfg.APIURL
|
||||
if apiURL == "" {
|
||||
apiURL = defaultOSVAPIURL
|
||||
}
|
||||
|
||||
timeout := cfg.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
|
||||
return &Scanner{
|
||||
config: cfg,
|
||||
httpClient: &http.Client{
|
||||
Timeout: timeout,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the scanner name
|
||||
func (s *Scanner) Name() string {
|
||||
return ScannerName
|
||||
}
|
||||
|
||||
// Scan scans a package for vulnerabilities using OSV.dev API
|
||||
func (s *Scanner) Scan(ctx context.Context, registry, packageName, version string, filePath string) (*metadata.ScanResult, error) {
|
||||
// Convert registry to OSV ecosystem
|
||||
ecosystem := s.registryToEcosystem(registry)
|
||||
|
||||
// Build request
|
||||
req := OSVRequest{
|
||||
Package: PackageInfo{
|
||||
Name: packageName,
|
||||
Ecosystem: ecosystem,
|
||||
},
|
||||
Version: version,
|
||||
}
|
||||
|
||||
// Marshal request
|
||||
reqBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal OSV request: %w", err)
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
apiURL := s.config.APIURL
|
||||
if apiURL == "" {
|
||||
apiURL = defaultOSVAPIURL
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OSV request: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Execute request
|
||||
resp, err := s.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("OSV API request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
// Read response
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read OSV response: %w", err)
|
||||
}
|
||||
|
||||
// Check status code
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("OSV API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var osvResp OSVResponse
|
||||
if err := json.Unmarshal(body, &osvResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse OSV response: %w", err)
|
||||
}
|
||||
|
||||
// Convert to metadata.ScanResult
|
||||
return s.convertOSVResult(&osvResp, registry, packageName, version), nil
|
||||
}
|
||||
|
||||
// registryToEcosystem converts our registry name to OSV ecosystem
|
||||
func (s *Scanner) registryToEcosystem(registry string) string {
|
||||
switch strings.ToLower(registry) {
|
||||
case "npm":
|
||||
return "npm"
|
||||
case "pypi":
|
||||
return "PyPI"
|
||||
case "go":
|
||||
return "Go"
|
||||
case "maven":
|
||||
return "Maven"
|
||||
case "nuget":
|
||||
return "NuGet"
|
||||
case "cargo", "crates":
|
||||
return "crates.io"
|
||||
case "rubygems":
|
||||
return "RubyGems"
|
||||
default:
|
||||
return registry
|
||||
}
|
||||
}
|
||||
|
||||
// convertOSVResult converts OSV response to metadata.ScanResult
|
||||
func (s *Scanner) convertOSVResult(osvResp *OSVResponse, registry, packageName, version string) *metadata.ScanResult {
|
||||
vulnerabilities := make([]metadata.Vulnerability, 0, len(osvResp.Vulns))
|
||||
severityCounts := make(map[string]int)
|
||||
|
||||
for _, vuln := range osvResp.Vulns {
|
||||
// Determine severity from various sources
|
||||
severity := s.determineSeverity(&vuln)
|
||||
severityCounts[severity]++
|
||||
|
||||
// Extract references
|
||||
references := make([]string, 0, len(vuln.References))
|
||||
for _, ref := range vuln.References {
|
||||
references = append(references, ref.URL)
|
||||
}
|
||||
|
||||
// Find fixed version
|
||||
fixedVersion := s.findFixedVersion(&vuln, version)
|
||||
|
||||
vulnerabilities = append(vulnerabilities, metadata.Vulnerability{
|
||||
ID: vuln.ID,
|
||||
Severity: severity,
|
||||
Title: vuln.Summary,
|
||||
Description: vuln.Details,
|
||||
References: references,
|
||||
FixedIn: fixedVersion,
|
||||
})
|
||||
}
|
||||
|
||||
// Determine overall status
|
||||
status := metadata.ScanStatusClean
|
||||
if len(vulnerabilities) > 0 {
|
||||
status = metadata.ScanStatusVulnerable
|
||||
}
|
||||
|
||||
return &metadata.ScanResult{
|
||||
ID: uuid.New().String(),
|
||||
Registry: registry,
|
||||
PackageName: packageName,
|
||||
PackageVersion: version,
|
||||
Scanner: s.Name(),
|
||||
ScannedAt: time.Now(),
|
||||
Status: status,
|
||||
VulnerabilityCount: len(vulnerabilities),
|
||||
Vulnerabilities: vulnerabilities,
|
||||
Details: map[string]interface{}{
|
||||
"ecosystem": s.registryToEcosystem(registry),
|
||||
"severity_counts": severityCounts,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// determineSeverity extracts severity from OSV vulnerability
|
||||
func (s *Scanner) determineSeverity(vuln *OSVVulnerability) string {
|
||||
var rawSeverity string
|
||||
|
||||
// Try to get severity from CVSS
|
||||
for _, sev := range vuln.Severity {
|
||||
if sev.Type == "CVSS_V3" || sev.Type == "CVSS_V2" {
|
||||
// Parse CVSS score to severity
|
||||
score := sev.Score
|
||||
if strings.Contains(strings.ToUpper(score), "CRITICAL") {
|
||||
rawSeverity = "CRITICAL"
|
||||
} else if strings.Contains(strings.ToUpper(score), "HIGH") {
|
||||
rawSeverity = "HIGH"
|
||||
} else if strings.Contains(strings.ToUpper(score), "MEDIUM") || strings.Contains(strings.ToUpper(score), "MODERATE") {
|
||||
rawSeverity = "MODERATE"
|
||||
} else if strings.Contains(strings.ToUpper(score), "LOW") {
|
||||
rawSeverity = "LOW"
|
||||
}
|
||||
if rawSeverity != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check database_specific for severity if not found in CVSS
|
||||
if rawSeverity == "" && vuln.DatabaseSpecific != nil {
|
||||
if sev, ok := vuln.DatabaseSpecific["severity"].(string); ok {
|
||||
rawSeverity = sev
|
||||
}
|
||||
}
|
||||
|
||||
// Default to MODERATE if unknown
|
||||
if rawSeverity == "" {
|
||||
rawSeverity = "MODERATE"
|
||||
}
|
||||
|
||||
// Normalize to standard severity values
|
||||
return metadata.NormalizeSeverity(rawSeverity)
|
||||
}
|
||||
|
||||
// findFixedVersion extracts the fixed version from OSV affected ranges
|
||||
func (s *Scanner) findFixedVersion(vuln *OSVVulnerability, currentVersion string) string {
|
||||
for _, affected := range vuln.Affected {
|
||||
for _, r := range affected.Ranges {
|
||||
for _, event := range r.Events {
|
||||
if event.Fixed != "" {
|
||||
return event.Fixed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Health checks if OSV API is reachable
|
||||
func (s *Scanner) Health(ctx context.Context) error {
|
||||
// Make a simple request to check API availability
|
||||
apiURL := s.config.APIURL
|
||||
if apiURL == "" {
|
||||
apiURL = defaultOSVAPIURL
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", strings.Replace(apiURL, "/query", "", 1), nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create health check request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("OSV API not reachable: %w", err)
|
||||
}
|
||||
defer resp.Body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
log.Debug().Int("status", resp.StatusCode).Msg("OSV health check passed")
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,209 @@
|
||||
package pipaudit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/config"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// ScannerName is the name of this scanner
|
||||
const ScannerName = "pip-audit"
|
||||
|
||||
// Scanner implements the pip-audit vulnerability scanner
|
||||
type Scanner struct {
|
||||
config config.PipAuditConfig
|
||||
}
|
||||
|
||||
// New creates a new pip-audit scanner
|
||||
func New(cfg config.PipAuditConfig) *Scanner {
|
||||
return &Scanner{
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the scanner name
|
||||
func (s *Scanner) Name() string {
|
||||
return ScannerName
|
||||
}
|
||||
|
||||
// Scan scans a Python package using pip-audit
|
||||
func (s *Scanner) Scan(ctx context.Context, registry, packageName, version string, filePath string) (*metadata.ScanResult, error) {
|
||||
// Only scan PyPI packages
|
||||
if registry != "pypi" {
|
||||
return &metadata.ScanResult{
|
||||
ID: uuid.New().String(),
|
||||
Registry: registry,
|
||||
PackageName: packageName,
|
||||
PackageVersion: version,
|
||||
Scanner: ScannerName,
|
||||
ScannedAt: time.Now(),
|
||||
Status: metadata.ScanStatusClean,
|
||||
VulnerabilityCount: 0,
|
||||
Vulnerabilities: []metadata.Vulnerability{},
|
||||
Details: map[string]interface{}{
|
||||
"skipped": "pip-audit only supports PyPI packages",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("scanner", ScannerName).
|
||||
Str("package", packageName).
|
||||
Str("version", version).
|
||||
Msg("Starting pip-audit scan")
|
||||
|
||||
// Create a temporary directory
|
||||
tmpDir, err := os.MkdirTemp("", "pip-audit-*")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create temp dir: %w", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Copy the wheel/tar.gz file to temp directory
|
||||
tmpFile := filepath.Join(tmpDir, filepath.Base(filePath))
|
||||
if err := s.copyFile(filePath, tmpFile); err != nil {
|
||||
return nil, fmt.Errorf("failed to copy file: %w", err)
|
||||
}
|
||||
|
||||
// Run pip-audit on the package file
|
||||
cmd := exec.CommandContext(ctx, "pip-audit", "-r", tmpFile, "--format", "json") // #nosec G204 -- pip-audit command with temp file
|
||||
output, _ := cmd.CombinedOutput() // pip-audit returns non-zero when vulns found
|
||||
|
||||
// Parse pip-audit output
|
||||
var auditResult PipAuditResult
|
||||
if len(output) > 0 {
|
||||
if err := json.Unmarshal(output, &auditResult); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to parse pip-audit output")
|
||||
return s.emptyResult(registry, packageName, version), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to our format
|
||||
result := s.convertResult(&auditResult, registry, packageName, version)
|
||||
|
||||
log.Info().
|
||||
Str("scanner", ScannerName).
|
||||
Str("package", packageName).
|
||||
Int("vulnerabilities", result.VulnerabilityCount).
|
||||
Msg("pip-audit scan completed")
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Health checks if pip-audit is available
|
||||
func (s *Scanner) Health(ctx context.Context) error {
|
||||
cmd := exec.CommandContext(ctx, "pip-audit", "--version")
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("pip-audit not available: %w (install with: pip install pip-audit)", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// copyFile copies a file from src to dst
|
||||
func (s *Scanner) copyFile(src, dst string) error {
|
||||
input, err := os.ReadFile(src) // #nosec G304 -- Source path is from scanner, controlled
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(dst, input, 0600)
|
||||
}
|
||||
|
||||
// emptyResult returns an empty scan result
|
||||
func (s *Scanner) emptyResult(registry, packageName, version string) *metadata.ScanResult {
|
||||
return &metadata.ScanResult{
|
||||
ID: uuid.New().String(),
|
||||
Registry: registry,
|
||||
PackageName: packageName,
|
||||
PackageVersion: version,
|
||||
Scanner: ScannerName,
|
||||
ScannedAt: time.Now(),
|
||||
Status: metadata.ScanStatusClean,
|
||||
VulnerabilityCount: 0,
|
||||
Vulnerabilities: []metadata.Vulnerability{},
|
||||
Details: map[string]interface{}{},
|
||||
}
|
||||
}
|
||||
|
||||
// convertResult converts pip-audit output to our ScanResult format
|
||||
func (s *Scanner) convertResult(auditResult *PipAuditResult, registry, packageName, version string) *metadata.ScanResult {
|
||||
vulnerabilities := make([]metadata.Vulnerability, 0)
|
||||
severityCounts := make(map[string]int)
|
||||
|
||||
for _, dep := range auditResult.Dependencies {
|
||||
for _, vuln := range dep.Vulns {
|
||||
// Map pip-audit severity to our standard
|
||||
severity := s.mapSeverity(vuln.ID)
|
||||
normalizedSeverity := metadata.NormalizeSeverity(severity)
|
||||
severityCounts[normalizedSeverity]++
|
||||
|
||||
// Get fixed versions
|
||||
fixedIn := ""
|
||||
if len(vuln.FixVersions) > 0 {
|
||||
fixedIn = vuln.FixVersions[0]
|
||||
}
|
||||
|
||||
vulnerabilities = append(vulnerabilities, metadata.Vulnerability{
|
||||
ID: vuln.ID,
|
||||
Severity: normalizedSeverity,
|
||||
Title: vuln.ID,
|
||||
Description: vuln.Description,
|
||||
References: []string{fmt.Sprintf("https://osv.dev/vulnerability/%s", vuln.ID)},
|
||||
FixedIn: fixedIn,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
status := metadata.ScanStatusClean
|
||||
if len(vulnerabilities) > 0 {
|
||||
status = metadata.ScanStatusVulnerable
|
||||
}
|
||||
|
||||
return &metadata.ScanResult{
|
||||
ID: uuid.New().String(),
|
||||
Registry: registry,
|
||||
PackageName: packageName,
|
||||
PackageVersion: version,
|
||||
Scanner: ScannerName,
|
||||
ScannedAt: time.Now(),
|
||||
Status: status,
|
||||
VulnerabilityCount: len(vulnerabilities),
|
||||
Vulnerabilities: vulnerabilities,
|
||||
Details: map[string]interface{}{
|
||||
"severity_counts": severityCounts,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// mapSeverity maps vulnerability ID patterns to severity levels
|
||||
func (s *Scanner) mapSeverity(vulnID string) string {
|
||||
// pip-audit doesn't provide severity directly
|
||||
// Default to MODERATE for all findings
|
||||
return "MODERATE"
|
||||
}
|
||||
|
||||
// PipAuditResult represents pip-audit JSON output
|
||||
type PipAuditResult struct {
|
||||
Dependencies []PipDependency `json:"dependencies"`
|
||||
}
|
||||
|
||||
type PipDependency struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
Vulns []PipVuln `json:"vulns"`
|
||||
}
|
||||
|
||||
type PipVuln struct {
|
||||
ID string `json:"id"`
|
||||
Description string `json:"description"`
|
||||
FixVersions []string `json:"fix_versions"`
|
||||
Aliases []string `json:"aliases"`
|
||||
}
|
||||
@@ -0,0 +1,219 @@
|
||||
package scanner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/storage"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// RescanWorker handles periodic re-scanning of cached packages
|
||||
type RescanWorker struct {
|
||||
manager *Manager
|
||||
metadataStore metadata.MetadataStore
|
||||
storage storage.StorageBackend
|
||||
interval time.Duration
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewRescanWorker creates a new rescan worker
|
||||
func NewRescanWorker(manager *Manager, metadataStore metadata.MetadataStore, storageBackend storage.StorageBackend, interval time.Duration) *RescanWorker {
|
||||
return &RescanWorker{
|
||||
manager: manager,
|
||||
metadataStore: metadataStore,
|
||||
storage: storageBackend,
|
||||
interval: interval,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the periodic re-scanning process
|
||||
func (w *RescanWorker) Start(ctx context.Context) {
|
||||
if !w.manager.enabled || w.interval == 0 {
|
||||
log.Info().Msg("Rescan worker disabled")
|
||||
return
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Dur("interval", w.interval).
|
||||
Msg("Starting package rescan worker")
|
||||
|
||||
ticker := time.NewTicker(w.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Run initial scan immediately on startup
|
||||
log.Info().Msg("Running initial package scan on startup")
|
||||
w.rescanPackages(ctx)
|
||||
log.Info().
|
||||
Dur("next_scan", w.interval).
|
||||
Msg("Initial scan complete, next scan scheduled")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
w.rescanPackages(ctx)
|
||||
case <-w.stopCh:
|
||||
log.Info().Msg("Rescan worker stopped")
|
||||
return
|
||||
case <-ctx.Done():
|
||||
log.Info().Msg("Rescan worker stopped (context cancelled)")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the rescan worker
|
||||
func (w *RescanWorker) Stop() {
|
||||
close(w.stopCh)
|
||||
}
|
||||
|
||||
// rescanPackages re-scans packages that need updating
|
||||
func (w *RescanWorker) rescanPackages(ctx context.Context) {
|
||||
log.Info().Msg("Starting package rescan cycle - checking all packages for scan status")
|
||||
|
||||
// Get all packages
|
||||
packages, err := w.metadataStore.ListPackages(ctx, &metadata.ListOptions{})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to list packages for rescan")
|
||||
return
|
||||
}
|
||||
|
||||
scanned := 0
|
||||
skipped := 0
|
||||
failed := 0
|
||||
|
||||
for _, pkg := range packages {
|
||||
// Skip metadata entries (npm metadata pages, pypi pages, etc.)
|
||||
if pkg.Version == "list" || pkg.Version == "latest" || pkg.Version == "metadata" || pkg.Version == "page" {
|
||||
skipped++
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if package needs rescanning
|
||||
needsRescan, err := w.needsRescan(ctx, pkg)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("package", pkg.Name).
|
||||
Str("version", pkg.Version).
|
||||
Msg("Failed to check rescan status")
|
||||
failed++
|
||||
continue
|
||||
}
|
||||
|
||||
if !needsRescan {
|
||||
log.Debug().
|
||||
Str("package", pkg.Name).
|
||||
Str("version", pkg.Version).
|
||||
Bool("security_scanned", pkg.SecurityScanned).
|
||||
Msg("Package does not need rescanning, skipping")
|
||||
skipped++
|
||||
continue
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("registry", pkg.Registry).
|
||||
Str("package", pkg.Name).
|
||||
Str("version", pkg.Version).
|
||||
Msg("Package needs rescanning")
|
||||
|
||||
// Get file path from storage using the storage key from the package metadata
|
||||
if pkg.StorageKey == "" {
|
||||
log.Warn().
|
||||
Str("registry", pkg.Registry).
|
||||
Str("package", pkg.Name).
|
||||
Str("version", pkg.Version).
|
||||
Msg("Package has no storage key, skipping rescan")
|
||||
failed++
|
||||
continue
|
||||
}
|
||||
|
||||
filePath, err := w.getPackageFilePath(ctx, pkg.StorageKey)
|
||||
if err != nil {
|
||||
log.Warn().
|
||||
Err(err).
|
||||
Str("registry", pkg.Registry).
|
||||
Str("package", pkg.Name).
|
||||
Str("version", pkg.Version).
|
||||
Str("storage_key", pkg.StorageKey).
|
||||
Msg("Failed to get package file path, skipping rescan")
|
||||
failed++
|
||||
continue
|
||||
}
|
||||
|
||||
if filePath == "" {
|
||||
log.Debug().
|
||||
Str("registry", pkg.Registry).
|
||||
Str("package", pkg.Name).
|
||||
Str("version", pkg.Version).
|
||||
Msg("No local file path available, skipping rescan")
|
||||
skipped++
|
||||
continue
|
||||
}
|
||||
|
||||
// Perform the actual scan
|
||||
if err := w.manager.ScanPackage(ctx, pkg.Registry, pkg.Name, pkg.Version, filePath); err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("registry", pkg.Registry).
|
||||
Str("package", pkg.Name).
|
||||
Str("version", pkg.Version).
|
||||
Msg("Failed to rescan package")
|
||||
failed++
|
||||
continue
|
||||
}
|
||||
|
||||
scanned++
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Int("total", len(packages)).
|
||||
Int("scanned", scanned).
|
||||
Int("skipped", skipped).
|
||||
Int("failed", failed).
|
||||
Msg("Rescan cycle completed")
|
||||
}
|
||||
|
||||
// needsRescan checks if a package needs to be rescanned
|
||||
func (w *RescanWorker) needsRescan(ctx context.Context, pkg *metadata.Package) (bool, error) {
|
||||
// Get latest scan result
|
||||
scanResult, err := w.metadataStore.GetScanResult(ctx, pkg.Registry, pkg.Name, pkg.Version)
|
||||
if err != nil {
|
||||
// No scan result - needs scanning
|
||||
log.Debug().
|
||||
Str("package", pkg.Name).
|
||||
Str("version", pkg.Version).
|
||||
Msg("Package has no scan result, needs scanning")
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// If package is not marked as scanned but has scan result, it's a stale state - rescan
|
||||
if !pkg.SecurityScanned {
|
||||
log.Info().
|
||||
Str("package", pkg.Name).
|
||||
Str("version", pkg.Version).
|
||||
Msg("Package has scan result but security_scanned flag is false, needs update")
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Check if scan is older than rescan interval
|
||||
timeSinceLastScan := time.Since(scanResult.ScannedAt)
|
||||
if timeSinceLastScan >= w.interval {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// getPackageFilePath retrieves the local file path for a package from storage
|
||||
func (w *RescanWorker) getPackageFilePath(ctx context.Context, storageKey string) (string, error) {
|
||||
// Check if storage backend supports local paths
|
||||
if localProvider, ok := w.storage.(storage.LocalPathProvider); ok {
|
||||
return localProvider.GetLocalPath(ctx, storageKey)
|
||||
}
|
||||
|
||||
// If storage doesn't support local paths (S3, SMB), we can't rescan
|
||||
return "", nil
|
||||
}
|
||||
@@ -0,0 +1,515 @@
|
||||
package scanner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/config"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/scanner/ghsa"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/scanner/govulncheck"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/scanner/grype"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/scanner/npmaudit"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/scanner/osv"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/scanner/pipaudit"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/scanner/trivy"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Scanner defines the interface for security scanners
|
||||
type Scanner interface {
|
||||
// Name returns the scanner name
|
||||
Name() string
|
||||
|
||||
// Scan scans a package for vulnerabilities
|
||||
Scan(ctx context.Context, registry, packageName, version string, filePath string) (*metadata.ScanResult, error)
|
||||
|
||||
// Health checks scanner health
|
||||
Health(ctx context.Context) error
|
||||
}
|
||||
|
||||
// DatabaseUpdater is implemented by scanners that need database updates
|
||||
type DatabaseUpdater interface {
|
||||
UpdateDatabase(ctx context.Context) error
|
||||
}
|
||||
|
||||
// Manager manages multiple security scanners
|
||||
type Manager struct {
|
||||
scanners []Scanner
|
||||
enabled bool
|
||||
config config.SecurityConfig
|
||||
metadataStore metadata.MetadataStore
|
||||
}
|
||||
|
||||
// New creates a new scanner manager with configured scanners
|
||||
func New(cfg config.SecurityConfig, metadataStore metadata.MetadataStore) (*Manager, error) {
|
||||
manager := &Manager{
|
||||
scanners: make([]Scanner, 0),
|
||||
enabled: cfg.Enabled,
|
||||
config: cfg,
|
||||
metadataStore: metadataStore,
|
||||
}
|
||||
|
||||
if !cfg.Enabled {
|
||||
log.Info().Msg("Security scanning disabled")
|
||||
return manager, nil
|
||||
}
|
||||
|
||||
// Initialize Trivy scanner
|
||||
if cfg.Scanners.Trivy.Enabled {
|
||||
trivyScanner := trivy.New(cfg.Scanners.Trivy)
|
||||
manager.RegisterScanner(trivyScanner)
|
||||
log.Info().Msg("Trivy scanner enabled")
|
||||
|
||||
// Update database on startup if configured
|
||||
if cfg.UpdateDBOnStartup {
|
||||
if err := trivyScanner.UpdateDatabase(context.Background()); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to update Trivy database on startup")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize OSV scanner
|
||||
if cfg.Scanners.OSV.Enabled {
|
||||
osvScanner := osv.New(cfg.Scanners.OSV)
|
||||
manager.RegisterScanner(osvScanner)
|
||||
log.Info().Msg("OSV scanner enabled")
|
||||
}
|
||||
|
||||
// Initialize Grype scanner
|
||||
if cfg.Scanners.Grype.Enabled {
|
||||
grypeScanner := grype.New(cfg.Scanners.Grype)
|
||||
manager.RegisterScanner(grypeScanner)
|
||||
log.Info().Msg("Grype scanner enabled")
|
||||
|
||||
// Update database on startup if configured
|
||||
if cfg.UpdateDBOnStartup {
|
||||
if err := grypeScanner.UpdateDatabase(context.Background()); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to update Grype database on startup")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize govulncheck scanner
|
||||
if cfg.Scanners.Govulncheck.Enabled {
|
||||
govulncheckScanner := govulncheck.New(cfg.Scanners.Govulncheck)
|
||||
manager.RegisterScanner(govulncheckScanner)
|
||||
log.Info().Msg("govulncheck scanner enabled")
|
||||
}
|
||||
|
||||
// Initialize npm-audit scanner
|
||||
if cfg.Scanners.NpmAudit.Enabled {
|
||||
npmAuditScanner := npmaudit.New(cfg.Scanners.NpmAudit)
|
||||
manager.RegisterScanner(npmAuditScanner)
|
||||
log.Info().Msg("npm-audit scanner enabled")
|
||||
}
|
||||
|
||||
// Initialize pip-audit scanner
|
||||
if cfg.Scanners.PipAudit.Enabled {
|
||||
pipAuditScanner := pipaudit.New(cfg.Scanners.PipAudit)
|
||||
manager.RegisterScanner(pipAuditScanner)
|
||||
log.Info().Msg("pip-audit scanner enabled")
|
||||
}
|
||||
|
||||
// Initialize GitHub Advisory Database scanner
|
||||
if cfg.Scanners.GHSA.Enabled {
|
||||
ghsaScanner := ghsa.New(cfg.Scanners.GHSA)
|
||||
manager.RegisterScanner(ghsaScanner)
|
||||
log.Info().Msg("GitHub Advisory Database scanner enabled")
|
||||
}
|
||||
|
||||
if len(manager.scanners) == 0 {
|
||||
log.Warn().Msg("Security scanning enabled but no scanners configured")
|
||||
}
|
||||
|
||||
return manager, nil
|
||||
}
|
||||
|
||||
// RegisterScanner registers a scanner
|
||||
func (m *Manager) RegisterScanner(scanner Scanner) {
|
||||
m.scanners = append(m.scanners, scanner)
|
||||
}
|
||||
|
||||
// ScanPackage scans a package using all registered scanners and saves results
|
||||
func (m *Manager) ScanPackage(ctx context.Context, registry, packageName, version string, filePath string) error {
|
||||
if !m.enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("registry", registry).
|
||||
Str("package", packageName).
|
||||
Str("version", version).
|
||||
Msg("Starting security scan")
|
||||
|
||||
// Collect results from all scanners
|
||||
var scanResults []*metadata.ScanResult
|
||||
scannerNames := make([]string, 0)
|
||||
|
||||
for _, scanner := range m.scanners {
|
||||
// Skip scanners that don't support this registry
|
||||
if !m.shouldRunScanner(scanner.Name(), registry) {
|
||||
log.Debug().
|
||||
Str("scanner", scanner.Name()).
|
||||
Str("registry", registry).
|
||||
Msg("Skipping scanner - not compatible with registry")
|
||||
continue
|
||||
}
|
||||
|
||||
result, err := scanner.Scan(ctx, registry, packageName, version, filePath)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("scanner", scanner.Name()).
|
||||
Str("package", packageName).
|
||||
Msg("Scanner failed")
|
||||
continue
|
||||
}
|
||||
|
||||
scanResults = append(scanResults, result)
|
||||
scannerNames = append(scannerNames, scanner.Name())
|
||||
|
||||
log.Info().
|
||||
Str("scanner", scanner.Name()).
|
||||
Str("package", packageName).
|
||||
Str("status", string(result.Status)).
|
||||
Int("vulnerabilities", result.VulnerabilityCount).
|
||||
Msg("Scan completed")
|
||||
}
|
||||
|
||||
// If no scanners succeeded, return
|
||||
if len(scanResults) == 0 {
|
||||
log.Warn().
|
||||
Str("package", packageName).
|
||||
Msg("All scanners failed, no results to save")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Merge and deduplicate results from all scanners
|
||||
mergedResult := m.mergeResults(scanResults, scannerNames)
|
||||
|
||||
// Save consolidated result to metadata store
|
||||
if err := m.metadataStore.SaveScanResult(ctx, mergedResult); err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("package", packageName).
|
||||
Msg("Failed to save consolidated scan result")
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("package", packageName).
|
||||
Str("status", string(mergedResult.Status)).
|
||||
Int("total_vulnerabilities", mergedResult.VulnerabilityCount).
|
||||
Int("unique_cves", len(mergedResult.Vulnerabilities)).
|
||||
Strs("scanners", scannerNames).
|
||||
Msg("Consolidated scan results saved")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// mergeResults merges and deduplicates scan results from multiple scanners
|
||||
func (m *Manager) mergeResults(results []*metadata.ScanResult, scannerNames []string) *metadata.ScanResult {
|
||||
if len(results) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use first result as base
|
||||
merged := &metadata.ScanResult{
|
||||
ID: results[0].ID,
|
||||
Registry: results[0].Registry,
|
||||
PackageName: results[0].PackageName,
|
||||
PackageVersion: results[0].PackageVersion,
|
||||
Scanner: strings.Join(scannerNames, "+"), // Combined scanner name
|
||||
ScannedAt: results[0].ScannedAt,
|
||||
Status: metadata.ScanStatusClean,
|
||||
Vulnerabilities: make([]metadata.Vulnerability, 0),
|
||||
Details: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
// Use map for deduplication - key is CVE ID in uppercase
|
||||
vulnMap := make(map[string]*metadata.Vulnerability)
|
||||
severityCounts := make(map[string]int)
|
||||
|
||||
// Merge vulnerabilities from all scanners
|
||||
for i, result := range results {
|
||||
scannerName := scannerNames[i]
|
||||
|
||||
// Track scanner details
|
||||
merged.Details[scannerName] = result.Details
|
||||
|
||||
// Add/merge vulnerabilities
|
||||
for _, vuln := range result.Vulnerabilities {
|
||||
cveKey := strings.ToUpper(vuln.ID)
|
||||
|
||||
// Check if CVE already exists
|
||||
if existing, exists := vulnMap[cveKey]; exists {
|
||||
// CVE found by multiple scanners - merge information
|
||||
log.Debug().
|
||||
Str("cve", vuln.ID).
|
||||
Strs("existing_scanners", existing.DetectedBy).
|
||||
Str("new_scanner", scannerName).
|
||||
Msg("CVE found by multiple scanners, merging")
|
||||
|
||||
// Add scanner to DetectedBy list
|
||||
existing.DetectedBy = append(existing.DetectedBy, scannerName)
|
||||
|
||||
// Prefer higher severity if different
|
||||
if m.compareSeverity(vuln.Severity, existing.Severity) > 0 {
|
||||
existing.Severity = vuln.Severity
|
||||
}
|
||||
|
||||
// Merge references (deduplicate URLs)
|
||||
refSet := make(map[string]bool)
|
||||
for _, ref := range existing.References {
|
||||
refSet[ref] = true
|
||||
}
|
||||
for _, ref := range vuln.References {
|
||||
if !refSet[ref] {
|
||||
existing.References = append(existing.References, ref)
|
||||
refSet[ref] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Prefer fixed_in version if not already set
|
||||
if existing.FixedIn == "" && vuln.FixedIn != "" {
|
||||
existing.FixedIn = vuln.FixedIn
|
||||
}
|
||||
|
||||
} else {
|
||||
// New CVE - add to map
|
||||
vulnCopy := vuln
|
||||
vulnCopy.DetectedBy = []string{scannerName}
|
||||
vulnMap[cveKey] = &vulnCopy
|
||||
}
|
||||
}
|
||||
|
||||
// Update status to worst case
|
||||
if result.Status == metadata.ScanStatusVulnerable {
|
||||
merged.Status = metadata.ScanStatusVulnerable
|
||||
} else if result.Status == metadata.ScanStatusPending && merged.Status != metadata.ScanStatusVulnerable {
|
||||
merged.Status = metadata.ScanStatusPending
|
||||
}
|
||||
}
|
||||
|
||||
// Convert map to slice and count severities
|
||||
for _, vuln := range vulnMap {
|
||||
merged.Vulnerabilities = append(merged.Vulnerabilities, *vuln)
|
||||
severityCounts[strings.ToUpper(vuln.Severity)]++
|
||||
}
|
||||
|
||||
// Update counts
|
||||
merged.VulnerabilityCount = len(merged.Vulnerabilities)
|
||||
merged.Details["severity_counts"] = severityCounts
|
||||
merged.Details["deduplication_summary"] = fmt.Sprintf(
|
||||
"Merged results from %d scanners (%s)",
|
||||
len(scannerNames),
|
||||
strings.Join(scannerNames, ", "),
|
||||
)
|
||||
|
||||
return merged
|
||||
}
|
||||
|
||||
// compareSeverity returns >0 if s1 is more severe than s2, <0 if less, 0 if equal
|
||||
func (m *Manager) compareSeverity(s1, s2 string) int {
|
||||
severityOrder := map[string]int{
|
||||
"CRITICAL": 4,
|
||||
"HIGH": 3,
|
||||
"MODERATE": 2,
|
||||
"MEDIUM": 2, // Support both for backwards compatibility
|
||||
"LOW": 1,
|
||||
"UNKNOWN": 0,
|
||||
}
|
||||
|
||||
v1 := severityOrder[strings.ToUpper(s1)]
|
||||
v2 := severityOrder[strings.ToUpper(s2)]
|
||||
|
||||
return v1 - v2
|
||||
}
|
||||
|
||||
// CheckVulnerabilities checks if a package exceeds vulnerability thresholds
|
||||
func (m *Manager) CheckVulnerabilities(ctx context.Context, registry, packageName, version string) (bool, string, error) {
|
||||
if !m.enabled {
|
||||
return false, "", nil
|
||||
}
|
||||
|
||||
// Get active CVE bypasses from database
|
||||
bypasses, err := m.metadataStore.GetActiveCVEBypasses(ctx)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to get CVE bypasses, continuing without bypasses")
|
||||
bypasses = []*metadata.CVEBypass{} // Continue without bypasses
|
||||
}
|
||||
|
||||
// Check if entire package is bypassed
|
||||
packageKey := fmt.Sprintf("%s/%s@%s", registry, packageName, version)
|
||||
packageKeyNoVersion := fmt.Sprintf("%s/%s", registry, packageName)
|
||||
|
||||
for _, bypass := range bypasses {
|
||||
if bypass.Type == metadata.BypassTypePackage && bypass.Active {
|
||||
if bypass.Target == packageKey || bypass.Target == packageKeyNoVersion {
|
||||
log.Info().
|
||||
Str("package", packageKey).
|
||||
Str("bypass_id", bypass.ID).
|
||||
Str("reason", bypass.Reason).
|
||||
Time("expires_at", bypass.ExpiresAt).
|
||||
Msg("Package bypassed by admin")
|
||||
return false, "", nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get latest scan result
|
||||
result, err := m.metadataStore.GetScanResult(ctx, registry, packageName, version)
|
||||
if err != nil {
|
||||
// No scan result found - allow download (will be scanned after)
|
||||
return false, "", nil
|
||||
}
|
||||
|
||||
// Build set of bypassed CVEs for fast lookup
|
||||
bypassedCVEs := make(map[string]*metadata.CVEBypass)
|
||||
for _, bypass := range bypasses {
|
||||
if bypass.Type == metadata.BypassTypeCVE && bypass.Active {
|
||||
// Check if bypass applies to this package (if AppliesTo is set)
|
||||
if bypass.AppliesTo != "" && bypass.AppliesTo != packageKey && bypass.AppliesTo != packageKeyNoVersion {
|
||||
continue // This bypass doesn't apply to this package
|
||||
}
|
||||
bypassedCVEs[strings.ToUpper(bypass.Target)] = bypass
|
||||
}
|
||||
}
|
||||
|
||||
// Count vulnerabilities by severity, excluding bypassed CVEs
|
||||
severityCounts := make(map[string]int)
|
||||
for _, vuln := range result.Vulnerabilities {
|
||||
// Check if this CVE is bypassed
|
||||
if bypass, ok := bypassedCVEs[strings.ToUpper(vuln.ID)]; ok {
|
||||
log.Debug().
|
||||
Str("cve", vuln.ID).
|
||||
Str("package", packageName).
|
||||
Str("bypass_id", bypass.ID).
|
||||
Str("reason", bypass.Reason).
|
||||
Time("expires_at", bypass.ExpiresAt).
|
||||
Msg("CVE bypassed by admin")
|
||||
continue
|
||||
}
|
||||
severityCounts[strings.ToUpper(vuln.Severity)]++
|
||||
}
|
||||
|
||||
// Check against thresholds
|
||||
thresholds := m.config.BlockThresholds
|
||||
|
||||
// Check critical
|
||||
if thresholds.Critical >= 0 && severityCounts["CRITICAL"] > thresholds.Critical {
|
||||
return true, fmt.Sprintf("Package has %d CRITICAL vulnerabilities (threshold: %d)",
|
||||
severityCounts["CRITICAL"], thresholds.Critical), nil
|
||||
}
|
||||
|
||||
// Check high
|
||||
if thresholds.High >= 0 && severityCounts["HIGH"] > thresholds.High {
|
||||
return true, fmt.Sprintf("Package has %d HIGH vulnerabilities (threshold: %d)",
|
||||
severityCounts["HIGH"], thresholds.High), nil
|
||||
}
|
||||
|
||||
// Check moderate (medium)
|
||||
moderateCount := severityCounts["MODERATE"] + severityCounts["MEDIUM"] // Support both for backwards compatibility
|
||||
if thresholds.Medium >= 0 && moderateCount > thresholds.Medium {
|
||||
return true, fmt.Sprintf("Package has %d MODERATE vulnerabilities (threshold: %d)",
|
||||
moderateCount, thresholds.Medium), nil
|
||||
}
|
||||
|
||||
// Check low
|
||||
if thresholds.Low >= 0 && severityCounts["LOW"] > thresholds.Low {
|
||||
return true, fmt.Sprintf("Package has %d LOW vulnerabilities (threshold: %d)",
|
||||
severityCounts["LOW"], thresholds.Low), nil
|
||||
}
|
||||
|
||||
// Check block on severity
|
||||
if m.config.BlockOnSeverity != "" && m.config.BlockOnSeverity != "none" {
|
||||
severity := strings.ToUpper(m.config.BlockOnSeverity)
|
||||
|
||||
// Block if any vulnerabilities at or above the specified severity exist
|
||||
switch severity {
|
||||
case "CRITICAL":
|
||||
if severityCounts["CRITICAL"] > 0 {
|
||||
return true, "Package has CRITICAL vulnerabilities", nil
|
||||
}
|
||||
case "HIGH":
|
||||
if severityCounts["CRITICAL"] > 0 || severityCounts["HIGH"] > 0 {
|
||||
return true, "Package has HIGH or CRITICAL vulnerabilities", nil
|
||||
}
|
||||
case "MODERATE", "MEDIUM":
|
||||
moderateCount := severityCounts["MODERATE"] + severityCounts["MEDIUM"]
|
||||
if severityCounts["CRITICAL"] > 0 || severityCounts["HIGH"] > 0 || moderateCount > 0 {
|
||||
return true, "Package has MODERATE, HIGH, or CRITICAL vulnerabilities", nil
|
||||
}
|
||||
case "LOW":
|
||||
if len(result.Vulnerabilities) > 0 {
|
||||
return true, "Package has vulnerabilities", nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, "", nil
|
||||
}
|
||||
|
||||
// UpdateDatabases updates vulnerability databases for all scanners
|
||||
func (m *Manager) UpdateDatabases(ctx context.Context) error {
|
||||
if !m.enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Info().Msg("Updating vulnerability databases")
|
||||
|
||||
for _, scanner := range m.scanners {
|
||||
if updater, ok := scanner.(DatabaseUpdater); ok {
|
||||
if err := updater.UpdateDatabase(ctx); err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("scanner", scanner.Name()).
|
||||
Msg("Failed to update database")
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Info().Msg("Vulnerability databases updated successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Health checks health of all scanners
|
||||
func (m *Manager) Health(ctx context.Context) error {
|
||||
if !m.enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, scanner := range m.scanners {
|
||||
if err := scanner.Health(ctx); err != nil {
|
||||
return fmt.Errorf("scanner %s health check failed: %w", scanner.Name(), err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// shouldRunScanner determines if a scanner should run for a given registry
|
||||
// Language-specific scanners only run for their target ecosystems
|
||||
func (m *Manager) shouldRunScanner(scannerName, registry string) bool {
|
||||
registry = strings.ToLower(registry)
|
||||
|
||||
// Language-specific scanners - only run for their target registry
|
||||
switch scannerName {
|
||||
case "govulncheck":
|
||||
return registry == "go"
|
||||
case "npm-audit":
|
||||
return registry == "npm"
|
||||
case "pip-audit":
|
||||
return registry == "pypi"
|
||||
|
||||
// Multi-ecosystem scanners - run for all registries
|
||||
case "trivy", "osv", "grype", "github-advisory-database":
|
||||
return true
|
||||
|
||||
// Default: allow scanner to run (for future scanners)
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,243 @@
|
||||
package trivy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/config"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// ScannerName is the name of this scanner
|
||||
const ScannerName = "trivy"
|
||||
|
||||
// Scanner implements the Scanner interface using Trivy
|
||||
type Scanner struct {
|
||||
config config.TrivyConfig
|
||||
}
|
||||
|
||||
// TrivyResult represents Trivy JSON output structure
|
||||
type TrivyResult struct {
|
||||
SchemaVersion int `json:"SchemaVersion"`
|
||||
ArtifactName string `json:"ArtifactName"`
|
||||
ArtifactType string `json:"ArtifactType"`
|
||||
Metadata TrivyMetadata `json:"Metadata"`
|
||||
Results []TrivyVulnResult `json:"Results"`
|
||||
}
|
||||
|
||||
type TrivyMetadata struct {
|
||||
OS *TrivyOS `json:"OS,omitempty"`
|
||||
RepoTags []string `json:"RepoTags,omitempty"`
|
||||
RepoDigests []string `json:"RepoDigests,omitempty"`
|
||||
ImageConfig *TrivyImageConfig `json:"ImageConfig,omitempty"`
|
||||
}
|
||||
|
||||
type TrivyOS struct {
|
||||
Family string `json:"Family"`
|
||||
Name string `json:"Name"`
|
||||
}
|
||||
|
||||
type TrivyImageConfig struct {
|
||||
Architecture string `json:"architecture"`
|
||||
Created string `json:"created"`
|
||||
}
|
||||
|
||||
type TrivyVulnResult struct {
|
||||
Target string `json:"Target"`
|
||||
Class string `json:"Class"`
|
||||
Type string `json:"Type"`
|
||||
Vulnerabilities []TrivyVulnerability `json:"Vulnerabilities"`
|
||||
}
|
||||
|
||||
type TrivyVulnerability struct {
|
||||
VulnerabilityID string `json:"VulnerabilityID"`
|
||||
PkgName string `json:"PkgName"`
|
||||
InstalledVersion string `json:"InstalledVersion"`
|
||||
FixedVersion string `json:"FixedVersion"`
|
||||
Severity string `json:"Severity"`
|
||||
Title string `json:"Title"`
|
||||
Description string `json:"Description"`
|
||||
References []string `json:"References"`
|
||||
PrimaryURL string `json:"PrimaryURL"`
|
||||
}
|
||||
|
||||
// New creates a new Trivy scanner
|
||||
func New(cfg config.TrivyConfig) *Scanner {
|
||||
return &Scanner{
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the scanner name
|
||||
func (s *Scanner) Name() string {
|
||||
return ScannerName
|
||||
}
|
||||
|
||||
// UpdateDatabase updates Trivy's vulnerability database
|
||||
func (s *Scanner) UpdateDatabase(ctx context.Context) error {
|
||||
log.Info().Msg("Updating Trivy vulnerability database")
|
||||
|
||||
cmd := exec.CommandContext(ctx, "trivy", "image", "--download-db-only")
|
||||
if s.config.CacheDB != "" {
|
||||
cmd.Env = append(os.Environ(), fmt.Sprintf("TRIVY_CACHE_DIR=%s", s.config.CacheDB))
|
||||
}
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update Trivy database: %w (output: %s)", err, string(output))
|
||||
}
|
||||
|
||||
log.Info().Msg("Trivy vulnerability database updated successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Scan scans a package for vulnerabilities using Trivy
|
||||
func (s *Scanner) Scan(ctx context.Context, registry, packageName, version string, filePath string) (*metadata.ScanResult, error) {
|
||||
// Set timeout
|
||||
if s.config.Timeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, s.config.Timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// Determine scan type based on registry
|
||||
scanType := s.determineScanType(registry, filePath)
|
||||
|
||||
// Build Trivy command
|
||||
args := []string{
|
||||
scanType,
|
||||
"--format", "json",
|
||||
"--quiet",
|
||||
filePath,
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "trivy", args...) // #nosec G204 -- trivy command with controlled arguments
|
||||
|
||||
// Set cache directory if configured
|
||||
if s.config.CacheDB != "" {
|
||||
cmd.Env = append(os.Environ(), fmt.Sprintf("TRIVY_CACHE_DIR=%s", s.config.CacheDB))
|
||||
}
|
||||
|
||||
// Execute scan
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
// Check if it's a timeout
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return &metadata.ScanResult{
|
||||
ID: uuid.New().String(),
|
||||
Registry: registry,
|
||||
PackageName: packageName,
|
||||
PackageVersion: version,
|
||||
Scanner: s.Name(),
|
||||
ScannedAt: time.Now(),
|
||||
Status: metadata.ScanStatusError,
|
||||
Details: map[string]interface{}{
|
||||
"error": "scan timeout",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("trivy scan failed: %w", err)
|
||||
}
|
||||
|
||||
// Parse Trivy output
|
||||
var trivyResult TrivyResult
|
||||
if err := json.Unmarshal(output, &trivyResult); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse Trivy output: %w", err)
|
||||
}
|
||||
|
||||
// Convert to metadata.ScanResult
|
||||
return s.convertTrivyResult(&trivyResult, registry, packageName, version), nil
|
||||
}
|
||||
|
||||
// determineScanType determines the appropriate Trivy scan type
|
||||
func (s *Scanner) determineScanType(registry, filePath string) string {
|
||||
// For now, use filesystem scan for packages
|
||||
// Container image scanning would need different handling
|
||||
ext := strings.ToLower(filePath[strings.LastIndex(filePath, ".")+1:])
|
||||
|
||||
switch registry {
|
||||
case "npm":
|
||||
return "fs" // Filesystem scan for npm packages
|
||||
case "pypi":
|
||||
return "fs" // Filesystem scan for Python packages
|
||||
case "go":
|
||||
return "fs" // Filesystem scan for Go modules
|
||||
default:
|
||||
// Check file extension
|
||||
if ext == "tar" || ext == "tgz" || ext == "gz" {
|
||||
return "fs"
|
||||
}
|
||||
return "fs"
|
||||
}
|
||||
}
|
||||
|
||||
// convertTrivyResult converts Trivy result to metadata.ScanResult
|
||||
func (s *Scanner) convertTrivyResult(trivyResult *TrivyResult, registry, packageName, version string) *metadata.ScanResult {
|
||||
vulnerabilities := make([]metadata.Vulnerability, 0)
|
||||
severityCounts := make(map[string]int)
|
||||
|
||||
// Aggregate all vulnerabilities from all results
|
||||
for _, result := range trivyResult.Results {
|
||||
for _, vuln := range result.Vulnerabilities {
|
||||
// Normalize severity to standard values (CRITICAL, HIGH, MODERATE, LOW)
|
||||
normalizedSeverity := metadata.NormalizeSeverity(vuln.Severity)
|
||||
|
||||
// Count by severity
|
||||
severityCounts[normalizedSeverity]++
|
||||
|
||||
// Add to vulnerabilities list
|
||||
vulnerabilities = append(vulnerabilities, metadata.Vulnerability{
|
||||
ID: vuln.VulnerabilityID,
|
||||
Severity: normalizedSeverity,
|
||||
Title: vuln.Title,
|
||||
Description: vuln.Description,
|
||||
References: vuln.References,
|
||||
FixedIn: vuln.FixedVersion,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Determine overall status
|
||||
status := metadata.ScanStatusClean
|
||||
if len(vulnerabilities) > 0 {
|
||||
status = metadata.ScanStatusVulnerable
|
||||
}
|
||||
|
||||
return &metadata.ScanResult{
|
||||
ID: uuid.New().String(),
|
||||
Registry: registry,
|
||||
PackageName: packageName,
|
||||
PackageVersion: version,
|
||||
Scanner: s.Name(),
|
||||
ScannedAt: time.Now(),
|
||||
Status: status,
|
||||
VulnerabilityCount: len(vulnerabilities),
|
||||
Vulnerabilities: vulnerabilities,
|
||||
Details: map[string]interface{}{
|
||||
"artifact_name": trivyResult.ArtifactName,
|
||||
"artifact_type": trivyResult.ArtifactType,
|
||||
"severity_counts": severityCounts,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Health checks if Trivy is available and working
|
||||
func (s *Scanner) Health(ctx context.Context) error {
|
||||
// Check if trivy command exists
|
||||
cmd := exec.CommandContext(ctx, "trivy", "--version")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("trivy not available: %w (output: %s)", err, string(output))
|
||||
}
|
||||
|
||||
log.Debug().Str("version", strings.TrimSpace(string(output))).Msg("Trivy health check passed")
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,130 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/config"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/health"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/logger"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metrics"
|
||||
)
|
||||
|
||||
// Server wraps http.Server with configuration
|
||||
type Server struct {
|
||||
*http.Server
|
||||
config *config.Config
|
||||
healthChecker *health.Checker
|
||||
}
|
||||
|
||||
// New creates a new HTTP server
|
||||
func New(cfg *config.Config, healthChecker *health.Checker) (*Server, error) {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// Register routes
|
||||
registerRoutes(mux, cfg, healthChecker)
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
|
||||
Handler: logger.Middleware(mux),
|
||||
ReadTimeout: cfg.Server.ReadTimeout,
|
||||
WriteTimeout: cfg.Server.WriteTimeout,
|
||||
IdleTimeout: cfg.Server.IdleTimeout,
|
||||
}
|
||||
|
||||
return &Server{
|
||||
Server: srv,
|
||||
config: cfg,
|
||||
healthChecker: healthChecker,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// registerRoutes registers all HTTP routes
|
||||
func registerRoutes(mux *http.ServeMux, cfg *config.Config, healthChecker *health.Checker) {
|
||||
// Health endpoints
|
||||
mux.HandleFunc("/health", healthChecker.HealthHandler())
|
||||
mux.HandleFunc("/health/ready", healthChecker.ReadyHandler())
|
||||
|
||||
// Metrics endpoint
|
||||
mux.Handle("/metrics", metrics.Handler())
|
||||
|
||||
// API endpoints
|
||||
mux.HandleFunc("/api/v1/info", handleInfo(cfg))
|
||||
|
||||
// Package manager proxy endpoints (placeholders for now)
|
||||
if cfg.Handlers.Go.Enabled {
|
||||
mux.HandleFunc("/go/", handleGoProxy())
|
||||
}
|
||||
if cfg.Handlers.NPM.Enabled {
|
||||
mux.HandleFunc("/npm/", handleNPMProxy())
|
||||
}
|
||||
if cfg.Handlers.PyPI.Enabled {
|
||||
mux.HandleFunc("/pypi/", handlePyPIProxy())
|
||||
}
|
||||
|
||||
// Root endpoint
|
||||
mux.HandleFunc("/", handleRoot())
|
||||
}
|
||||
|
||||
// handleInfo returns server information
|
||||
func handleInfo(cfg *config.Config) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
info := map[string]interface{}{
|
||||
"name": "GoHoarder",
|
||||
"version": "dev",
|
||||
"handlers": map[string]bool{
|
||||
"go": cfg.Handlers.Go.Enabled,
|
||||
"npm": cfg.Handlers.NPM.Enabled,
|
||||
"pypi": cfg.Handlers.PyPI.Enabled,
|
||||
},
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"success":true,"data":%v}`, toJSON(info))
|
||||
}
|
||||
}
|
||||
|
||||
// handleGoProxy handles Go module proxy requests
|
||||
func handleGoProxy() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO: Implement Go proxy handler
|
||||
http.Error(w, `{"success":false,"error":{"code":"NOT_IMPLEMENTED","message":"Go proxy not yet implemented"}}`, http.StatusNotImplemented)
|
||||
}
|
||||
}
|
||||
|
||||
// handleNPMProxy handles NPM registry requests
|
||||
func handleNPMProxy() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO: Implement NPM proxy handler
|
||||
http.Error(w, `{"success":false,"error":{"code":"NOT_IMPLEMENTED","message":"NPM proxy not yet implemented"}}`, http.StatusNotImplemented)
|
||||
}
|
||||
}
|
||||
|
||||
// handlePyPIProxy handles PyPI requests
|
||||
func handlePyPIProxy() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO: Implement PyPI proxy handler
|
||||
http.Error(w, `{"success":false,"error":{"code":"NOT_IMPLEMENTED","message":"PyPI proxy not yet implemented"}}`, http.StatusNotImplemented)
|
||||
}
|
||||
}
|
||||
|
||||
// handleRoot handles root path
|
||||
func handleRoot() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, `{"success":true,"data":{"message":"GoHoarder - Universal Package Cache Proxy","docs":"https://github.com/lukaszraczylo/gohoarder"}}`)
|
||||
}
|
||||
}
|
||||
|
||||
// toJSON is a simple JSON encoder (replace with proper implementation)
|
||||
func toJSON(v interface{}) string {
|
||||
// Simplified for now - proper implementation would use goccy/go-json
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
@@ -0,0 +1,415 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5" // #nosec G501 -- MD5 used for file checksums, not cryptographic security
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/errors"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metrics"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/storage"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// FilesystemStorage implements storage.StorageBackend for local filesystem
|
||||
type FilesystemStorage struct {
|
||||
basePath string
|
||||
quota int64
|
||||
mu sync.RWMutex
|
||||
used int64
|
||||
}
|
||||
|
||||
// New creates a new filesystem storage backend
|
||||
func New(basePath string, quota int64) (*FilesystemStorage, error) {
|
||||
// Create base directory if it doesn't exist
|
||||
if err := os.MkdirAll(basePath, 0750); err != nil {
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create base directory")
|
||||
}
|
||||
|
||||
fs := &FilesystemStorage{
|
||||
basePath: basePath,
|
||||
quota: quota,
|
||||
}
|
||||
|
||||
// Calculate initial usage
|
||||
if err := fs.calculateUsage(); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to calculate initial storage usage")
|
||||
}
|
||||
|
||||
return fs, nil
|
||||
}
|
||||
|
||||
// Get retrieves a file
|
||||
func (fs *FilesystemStorage) Get(ctx context.Context, key string) (io.ReadCloser, error) {
|
||||
// Check context
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
path := fs.keyToPath(key)
|
||||
|
||||
file, err := os.Open(path) // #nosec G304 -- Path is sanitized storage key
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
metrics.RecordStorageOperation("filesystem", "get", "not_found")
|
||||
return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key))
|
||||
}
|
||||
metrics.RecordStorageOperation("filesystem", "get", "error")
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to open file")
|
||||
}
|
||||
|
||||
metrics.RecordStorageOperation("filesystem", "get", "success")
|
||||
return file, nil
|
||||
}
|
||||
|
||||
// Put stores a file atomically
|
||||
func (fs *FilesystemStorage) Put(ctx context.Context, key string, data io.Reader, opts *storage.PutOptions) error {
|
||||
// Check context
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
path := fs.keyToPath(key)
|
||||
dir := filepath.Dir(path)
|
||||
|
||||
// Create directory
|
||||
if err := os.MkdirAll(dir, 0750); err != nil {
|
||||
metrics.RecordStorageOperation("filesystem", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create directory")
|
||||
}
|
||||
|
||||
// Create temp file for atomic write
|
||||
tempPath := path + ".tmp"
|
||||
tempFile, err := os.Create(tempPath) // #nosec G304 -- Temp path is constructed from sanitized storage key
|
||||
if err != nil {
|
||||
metrics.RecordStorageOperation("filesystem", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create temp file")
|
||||
}
|
||||
|
||||
// Calculate checksums while writing
|
||||
// NOTE: MD5 is used for integrity verification (checksums), not cryptographic security
|
||||
md5Hash := md5.New() // #nosec G401 -- MD5 used for file integrity check, not cryptographic security
|
||||
sha256Hash := sha256.New()
|
||||
multiWriter := io.MultiWriter(tempFile, md5Hash, sha256Hash)
|
||||
|
||||
written, err := io.Copy(multiWriter, data)
|
||||
if err != nil {
|
||||
tempFile.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
_ = os.Remove(tempPath) // #nosec G104 -- Cleanup, error not critical
|
||||
metrics.RecordStorageOperation("filesystem", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to write data")
|
||||
}
|
||||
|
||||
if err := tempFile.Close(); err != nil {
|
||||
_ = os.Remove(tempPath) // #nosec G104 -- Cleanup, error not critical
|
||||
metrics.RecordStorageOperation("filesystem", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to close temp file")
|
||||
}
|
||||
|
||||
// Check quota
|
||||
fs.mu.Lock()
|
||||
if fs.quota > 0 && fs.used+written > fs.quota {
|
||||
fs.mu.Unlock()
|
||||
_ = os.Remove(tempPath) // #nosec G104 -- Cleanup, error not critical
|
||||
metrics.RecordStorageOperation("filesystem", "put", "quota_exceeded")
|
||||
return errors.QuotaExceeded(fs.quota)
|
||||
}
|
||||
fs.used += written
|
||||
fs.mu.Unlock()
|
||||
|
||||
// Verify checksums if provided
|
||||
if opts != nil {
|
||||
md5Sum := hex.EncodeToString(md5Hash.Sum(nil))
|
||||
sha256Sum := hex.EncodeToString(sha256Hash.Sum(nil))
|
||||
|
||||
if opts.ChecksumMD5 != "" && opts.ChecksumMD5 != md5Sum {
|
||||
_ = os.Remove(tempPath) // #nosec G104 -- Cleanup, error not critical
|
||||
metrics.RecordStorageOperation("filesystem", "put", "checksum_error")
|
||||
return errors.New(errors.ErrCodeChecksumMismatch, "MD5 checksum mismatch")
|
||||
}
|
||||
|
||||
if opts.ChecksumSHA256 != "" && opts.ChecksumSHA256 != sha256Sum {
|
||||
_ = os.Remove(tempPath) // #nosec G104 -- Cleanup, error not critical
|
||||
metrics.RecordStorageOperation("filesystem", "put", "checksum_error")
|
||||
return errors.New(errors.ErrCodeChecksumMismatch, "SHA256 checksum mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
// Atomic rename
|
||||
if err := os.Rename(tempPath, path); err != nil {
|
||||
_ = os.Remove(tempPath) // #nosec G104 -- Cleanup, error not critical
|
||||
fs.mu.Lock()
|
||||
fs.used -= written
|
||||
currentUsed := fs.used
|
||||
fs.mu.Unlock()
|
||||
metrics.RecordStorageOperation("filesystem", "put", "error")
|
||||
metrics.UpdateCacheSize("filesystem", currentUsed)
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to rename temp file")
|
||||
}
|
||||
|
||||
fs.mu.RLock()
|
||||
currentUsed := fs.used
|
||||
fs.mu.RUnlock()
|
||||
|
||||
metrics.RecordStorageOperation("filesystem", "put", "success")
|
||||
metrics.UpdateCacheSize("filesystem", currentUsed)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a file
|
||||
func (fs *FilesystemStorage) Delete(ctx context.Context, key string) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
path := fs.keyToPath(key)
|
||||
|
||||
// Get size before deletion
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
metrics.RecordStorageOperation("filesystem", "delete", "not_found")
|
||||
return errors.NotFound(fmt.Sprintf("file not found: %s", key))
|
||||
}
|
||||
metrics.RecordStorageOperation("filesystem", "delete", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to stat file")
|
||||
}
|
||||
|
||||
size := info.Size()
|
||||
|
||||
if err := os.Remove(path); err != nil {
|
||||
metrics.RecordStorageOperation("filesystem", "delete", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to delete file")
|
||||
}
|
||||
|
||||
fs.mu.Lock()
|
||||
fs.used -= size
|
||||
currentUsed := fs.used
|
||||
fs.mu.Unlock()
|
||||
|
||||
metrics.RecordStorageOperation("filesystem", "delete", "success")
|
||||
metrics.UpdateCacheSize("filesystem", currentUsed)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists checks if a file exists
|
||||
func (fs *FilesystemStorage) Exists(ctx context.Context, key string) (bool, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
path := fs.keyToPath(key)
|
||||
_, err := os.Stat(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to check existence")
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// List lists files with prefix
|
||||
func (fs *FilesystemStorage) List(ctx context.Context, prefix string, opts *storage.ListOptions) ([]storage.StorageObject, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
searchPath := fs.keyToPath(prefix)
|
||||
var objects []storage.StorageObject
|
||||
|
||||
err := filepath.Walk(searchPath, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return nil // Skip errors
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert path back to key
|
||||
relPath, _ := filepath.Rel(fs.basePath, path)
|
||||
key := filepath.ToSlash(relPath)
|
||||
|
||||
objects = append(objects, storage.StorageObject{
|
||||
Key: key,
|
||||
Size: info.Size(),
|
||||
Modified: info.ModTime(),
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to list files")
|
||||
}
|
||||
|
||||
// Apply pagination if requested
|
||||
if opts != nil {
|
||||
start := opts.Offset
|
||||
end := len(objects)
|
||||
if opts.MaxResults > 0 && start+opts.MaxResults < end {
|
||||
end = start + opts.MaxResults
|
||||
}
|
||||
if start < len(objects) {
|
||||
objects = objects[start:end]
|
||||
}
|
||||
}
|
||||
|
||||
return objects, nil
|
||||
}
|
||||
|
||||
// Stat gets file metadata
|
||||
func (fs *FilesystemStorage) Stat(ctx context.Context, key string) (*storage.StorageInfo, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
path := fs.keyToPath(key)
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key))
|
||||
}
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to stat file")
|
||||
}
|
||||
|
||||
return &storage.StorageInfo{
|
||||
Key: key,
|
||||
Size: info.Size(),
|
||||
Modified: info.ModTime(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetQuota returns quota information
|
||||
func (fs *FilesystemStorage) GetQuota(ctx context.Context) (*storage.QuotaInfo, error) {
|
||||
fs.mu.RLock()
|
||||
used := fs.used
|
||||
fs.mu.RUnlock()
|
||||
|
||||
available := fs.quota - used
|
||||
if available < 0 {
|
||||
available = 0
|
||||
}
|
||||
|
||||
return &storage.QuotaInfo{
|
||||
Used: used,
|
||||
Available: available,
|
||||
Limit: fs.quota,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Health checks filesystem health
|
||||
func (fs *FilesystemStorage) Health(ctx context.Context) error {
|
||||
// Check if base path is accessible
|
||||
if _, err := os.Stat(fs.basePath); err != nil {
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "base path not accessible")
|
||||
}
|
||||
|
||||
// Try to create a temp file (sanitize path to prevent traversal)
|
||||
tempPath := filepath.Clean(filepath.Join(fs.basePath, ".health_check"))
|
||||
f, err := os.Create(tempPath)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "cannot write to storage")
|
||||
}
|
||||
f.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
_ = os.Remove(tempPath) // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the storage backend
|
||||
func (fs *FilesystemStorage) Close() error {
|
||||
// Nothing to close for filesystem
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLocalPath returns the local filesystem path for a storage key
|
||||
// This implements storage.LocalPathProvider interface
|
||||
func (fs *FilesystemStorage) GetLocalPath(ctx context.Context, key string) (string, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
path := fs.keyToPath(key)
|
||||
|
||||
// Verify file exists
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return "", errors.NotFound(fmt.Sprintf("file not found: %s", key))
|
||||
}
|
||||
return "", errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to stat file")
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// keyToPath converts a storage key to filesystem path
|
||||
func (fs *FilesystemStorage) keyToPath(key string) string {
|
||||
// Sanitize key to prevent path traversal
|
||||
key = filepath.Clean(key)
|
||||
|
||||
// Remove any leading slashes or dots
|
||||
key = strings.TrimPrefix(key, "/")
|
||||
|
||||
// Keep removing ../ until there are no more
|
||||
for strings.HasPrefix(key, "../") || strings.HasPrefix(key, "..\\") {
|
||||
key = strings.TrimPrefix(key, "../")
|
||||
key = strings.TrimPrefix(key, "..\\")
|
||||
}
|
||||
|
||||
// Final clean and ensure it's within base path
|
||||
key = filepath.Clean(key)
|
||||
if key == ".." || strings.HasPrefix(key, "../") || strings.HasPrefix(key, "..\\") {
|
||||
key = ""
|
||||
}
|
||||
|
||||
return filepath.Join(fs.basePath, key)
|
||||
}
|
||||
|
||||
// calculateUsage calculates current storage usage
|
||||
func (fs *FilesystemStorage) calculateUsage() error {
|
||||
var total int64
|
||||
|
||||
err := filepath.Walk(fs.basePath, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return nil // Skip errors
|
||||
}
|
||||
if !info.IsDir() {
|
||||
total += info.Size()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fs.mu.Lock()
|
||||
fs.used = total
|
||||
fs.mu.Unlock()
|
||||
|
||||
metrics.UpdateCacheSize("filesystem", total)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,757 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/storage"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type FilesystemStorageTestSuite struct {
|
||||
suite.Suite
|
||||
tempDir string
|
||||
fs *FilesystemStorage
|
||||
}
|
||||
|
||||
func (s *FilesystemStorageTestSuite) SetupTest() {
|
||||
var err error
|
||||
s.tempDir, err = os.MkdirTemp("", "gohoarder-test-*")
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.fs, err = New(s.tempDir, 1024*1024) // 1MB quota
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *FilesystemStorageTestSuite) TearDownTest() {
|
||||
if s.fs != nil {
|
||||
s.fs.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
}
|
||||
if s.tempDir != "" {
|
||||
_ = os.RemoveAll(s.tempDir) // #nosec G104 -- Cleanup
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilesystemStorageTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(FilesystemStorageTestSuite))
|
||||
}
|
||||
|
||||
// Test Put operation
|
||||
func (s *FilesystemStorageTestSuite) TestPut() {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
data string
|
||||
opts *storage.PutOptions
|
||||
expectError bool
|
||||
errorCheck func(error) bool
|
||||
}{
|
||||
{
|
||||
name: "successful put",
|
||||
key: "test/file.txt",
|
||||
data: "hello world",
|
||||
opts: nil,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "put with valid MD5 checksum",
|
||||
key: "test/checksummed.txt",
|
||||
data: "test data",
|
||||
opts: &storage.PutOptions{ChecksumMD5: "eb733a00c0c9d336e65691a37ab54293"},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "put with invalid MD5 checksum",
|
||||
key: "test/bad-checksum.txt",
|
||||
data: "test data",
|
||||
opts: &storage.PutOptions{ChecksumMD5: "invalid"},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "put with nested path",
|
||||
key: "deep/nested/path/file.txt",
|
||||
data: "nested content",
|
||||
opts: nil,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "put with path traversal attempt",
|
||||
key: "../../../etc/passwd",
|
||||
data: "malicious",
|
||||
opts: nil,
|
||||
expectError: false, // Should be sanitized, not error
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
ctx := context.Background()
|
||||
reader := strings.NewReader(tt.data)
|
||||
|
||||
err := s.fs.Put(ctx, tt.key, reader, tt.opts)
|
||||
|
||||
if tt.expectError {
|
||||
s.Error(err)
|
||||
} else {
|
||||
s.NoError(err)
|
||||
// Verify file exists
|
||||
exists, err := s.fs.Exists(ctx, tt.key)
|
||||
s.NoError(err)
|
||||
s.True(exists)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test Get operation
|
||||
func (s *FilesystemStorageTestSuite) TestGet() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Setup: Put a test file
|
||||
testData := "test content for retrieval"
|
||||
err := s.fs.Put(ctx, "test/get.txt", strings.NewReader(testData), nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
expectError bool
|
||||
expectData string
|
||||
}{
|
||||
{
|
||||
name: "get existing file",
|
||||
key: "test/get.txt",
|
||||
expectError: false,
|
||||
expectData: testData,
|
||||
},
|
||||
{
|
||||
name: "get non-existent file",
|
||||
key: "does/not/exist.txt",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
reader, err := s.fs.Get(ctx, tt.key)
|
||||
|
||||
if tt.expectError {
|
||||
s.Error(err)
|
||||
s.Nil(reader)
|
||||
} else {
|
||||
s.NoError(err)
|
||||
s.NotNil(reader)
|
||||
defer reader.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
data, err := io.ReadAll(reader)
|
||||
s.NoError(err)
|
||||
s.Equal(tt.expectData, string(data))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test Delete operation
|
||||
func (s *FilesystemStorageTestSuite) TestDelete() {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupKey string
|
||||
deleteKey string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "delete existing file",
|
||||
setupKey: "test/delete-me.txt",
|
||||
deleteKey: "test/delete-me.txt",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "delete non-existent file",
|
||||
setupKey: "",
|
||||
deleteKey: "does/not/exist.txt",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
// Setup
|
||||
if tt.setupKey != "" {
|
||||
err := s.fs.Put(ctx, tt.setupKey, strings.NewReader("to be deleted"), nil)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
// Test delete
|
||||
err := s.fs.Delete(ctx, tt.deleteKey)
|
||||
|
||||
if tt.expectError {
|
||||
s.Error(err)
|
||||
} else {
|
||||
s.NoError(err)
|
||||
// Verify file no longer exists
|
||||
exists, err := s.fs.Exists(ctx, tt.deleteKey)
|
||||
s.NoError(err)
|
||||
s.False(exists)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test Exists operation
|
||||
func (s *FilesystemStorageTestSuite) TestExists() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Setup: Put a test file
|
||||
err := s.fs.Put(ctx, "test/exists.txt", strings.NewReader("content"), nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
exists bool
|
||||
}{
|
||||
{
|
||||
name: "existing file",
|
||||
key: "test/exists.txt",
|
||||
exists: true,
|
||||
},
|
||||
{
|
||||
name: "non-existent file",
|
||||
key: "test/does-not-exist.txt",
|
||||
exists: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
exists, err := s.fs.Exists(ctx, tt.key)
|
||||
s.NoError(err)
|
||||
s.Equal(tt.exists, exists)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test List operation
|
||||
func (s *FilesystemStorageTestSuite) TestList() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Setup: Create multiple files
|
||||
files := []string{
|
||||
"packages/npm/react/17.0.1/package.json",
|
||||
"packages/npm/react/17.0.2/package.json",
|
||||
"packages/npm/vue/3.0.0/package.json",
|
||||
"packages/pypi/django/3.2.0/wheel.whl",
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
err := s.fs.Put(ctx, file, strings.NewReader("content"), nil)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
prefix string
|
||||
opts *storage.ListOptions
|
||||
expectedCount int
|
||||
expectedKeys []string
|
||||
}{
|
||||
{
|
||||
name: "list all npm packages",
|
||||
prefix: "packages/npm",
|
||||
opts: nil,
|
||||
expectedCount: 3,
|
||||
},
|
||||
{
|
||||
name: "list react packages",
|
||||
prefix: "packages/npm/react",
|
||||
opts: nil,
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "list with pagination",
|
||||
prefix: "packages/npm",
|
||||
opts: &storage.ListOptions{MaxResults: 2, Offset: 0},
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "list with offset",
|
||||
prefix: "packages/npm",
|
||||
opts: &storage.ListOptions{MaxResults: 2, Offset: 1},
|
||||
expectedCount: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
objects, err := s.fs.List(ctx, tt.prefix, tt.opts)
|
||||
s.NoError(err)
|
||||
s.Equal(tt.expectedCount, len(objects))
|
||||
|
||||
// Verify objects have required fields
|
||||
for _, obj := range objects {
|
||||
s.NotEmpty(obj.Key)
|
||||
s.Greater(obj.Size, int64(0))
|
||||
s.False(obj.Modified.IsZero())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test Stat operation
|
||||
func (s *FilesystemStorageTestSuite) TestStat() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Setup: Put a test file
|
||||
testData := "stat test content"
|
||||
testKey := "test/stat.txt"
|
||||
err := s.fs.Put(ctx, testKey, strings.NewReader(testData), nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "stat existing file",
|
||||
key: testKey,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "stat non-existent file",
|
||||
key: "does/not/exist.txt",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
info, err := s.fs.Stat(ctx, tt.key)
|
||||
|
||||
if tt.expectError {
|
||||
s.Error(err)
|
||||
s.Nil(info)
|
||||
} else {
|
||||
s.NoError(err)
|
||||
s.NotNil(info)
|
||||
s.Equal(tt.key, info.Key)
|
||||
s.Equal(int64(len(testData)), info.Size)
|
||||
s.False(info.Modified.IsZero())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test Quota enforcement
|
||||
func (s *FilesystemStorageTestSuite) TestQuotaEnforcement() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a new filesystem with small quota (100 bytes)
|
||||
smallQuotaDir, err := os.MkdirTemp("", "gohoarder-quota-*")
|
||||
s.Require().NoError(err)
|
||||
defer os.RemoveAll(smallQuotaDir)
|
||||
|
||||
smallFs, err := New(smallQuotaDir, 100)
|
||||
s.Require().NoError(err)
|
||||
defer smallFs.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
// First write should succeed
|
||||
err = smallFs.Put(ctx, "file1.txt", strings.NewReader("small content"), nil)
|
||||
s.NoError(err)
|
||||
|
||||
// Large write should fail due to quota
|
||||
largeData := strings.Repeat("x", 200)
|
||||
err = smallFs.Put(ctx, "large.txt", strings.NewReader(largeData), nil)
|
||||
s.Error(err)
|
||||
|
||||
// Verify quota info
|
||||
quotaInfo, err := smallFs.GetQuota(ctx)
|
||||
s.NoError(err)
|
||||
s.Equal(int64(100), quotaInfo.Limit)
|
||||
s.Greater(quotaInfo.Used, int64(0))
|
||||
s.LessOrEqual(quotaInfo.Used, quotaInfo.Limit)
|
||||
}
|
||||
|
||||
// Test GetQuota operation
|
||||
func (s *FilesystemStorageTestSuite) TestGetQuota() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Put some files
|
||||
err := s.fs.Put(ctx, "file1.txt", strings.NewReader("content1"), nil)
|
||||
s.Require().NoError(err)
|
||||
err = s.fs.Put(ctx, "file2.txt", strings.NewReader("content2"), nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
quotaInfo, err := s.fs.GetQuota(ctx)
|
||||
s.NoError(err)
|
||||
s.NotNil(quotaInfo)
|
||||
s.Equal(int64(1024*1024), quotaInfo.Limit)
|
||||
s.Greater(quotaInfo.Used, int64(0))
|
||||
s.Greater(quotaInfo.Available, int64(0))
|
||||
s.Equal(quotaInfo.Limit, quotaInfo.Used+quotaInfo.Available)
|
||||
}
|
||||
|
||||
// Test Health check
|
||||
func (s *FilesystemStorageTestSuite) TestHealth() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Healthy filesystem
|
||||
err := s.fs.Health(ctx)
|
||||
s.NoError(err)
|
||||
|
||||
// Unhealthy filesystem (removed directory)
|
||||
badDir := filepath.Join(s.tempDir, "nonexistent")
|
||||
badFs := &FilesystemStorage{basePath: badDir}
|
||||
err = badFs.Health(ctx)
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
// Test Context cancellation
|
||||
func (s *FilesystemStorageTestSuite) TestContextCancellation() {
|
||||
// Create cancelled context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func() error
|
||||
}{
|
||||
{
|
||||
name: "Get with cancelled context",
|
||||
fn: func() error {
|
||||
_, err := s.fs.Get(ctx, "test.txt")
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Put with cancelled context",
|
||||
fn: func() error {
|
||||
return s.fs.Put(ctx, "test.txt", strings.NewReader("data"), nil)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Delete with cancelled context",
|
||||
fn: func() error {
|
||||
return s.fs.Delete(ctx, "test.txt")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Exists with cancelled context",
|
||||
fn: func() error {
|
||||
_, err := s.fs.Exists(ctx, "test.txt")
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "List with cancelled context",
|
||||
fn: func() error {
|
||||
_, err := s.fs.List(ctx, "test", nil)
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Stat with cancelled context",
|
||||
fn: func() error {
|
||||
_, err := s.fs.Stat(ctx, "test.txt")
|
||||
return err
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
err := tt.fn()
|
||||
s.Error(err)
|
||||
s.Equal(context.Canceled, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test concurrent access (race condition testing)
|
||||
func (s *FilesystemStorageTestSuite) TestConcurrentAccess() {
|
||||
ctx := context.Background()
|
||||
numGoroutines := 10
|
||||
numOperations := 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Concurrent writes
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < numOperations; j++ {
|
||||
key := fmt.Sprintf("concurrent/%d/%d.txt", id, j)
|
||||
data := fmt.Sprintf("data-%d-%d", id, j)
|
||||
err := s.fs.Put(ctx, key, strings.NewReader(data), nil)
|
||||
s.NoError(err)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all files exist
|
||||
objects, err := s.fs.List(ctx, "concurrent", nil)
|
||||
s.NoError(err)
|
||||
s.Equal(numGoroutines*numOperations, len(objects))
|
||||
}
|
||||
|
||||
// Test concurrent reads and writes
|
||||
func (s *FilesystemStorageTestSuite) TestConcurrentReadsAndWrites() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Setup: Create some initial files
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("shared/file-%d.txt", i)
|
||||
err := s.fs.Put(ctx, key, strings.NewReader(fmt.Sprintf("initial-%d", i)), nil)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numReaders := 5
|
||||
numWriters := 5
|
||||
numOps := 50
|
||||
|
||||
// Concurrent readers
|
||||
for i := 0; i < numReaders; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < numOps; j++ {
|
||||
key := fmt.Sprintf("shared/file-%d.txt", j%10)
|
||||
reader, err := s.fs.Get(ctx, key)
|
||||
if err == nil {
|
||||
io.ReadAll(reader)
|
||||
reader.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Concurrent writers
|
||||
for i := 0; i < numWriters; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < numOps; j++ {
|
||||
key := fmt.Sprintf("shared/writer-%d-%d.txt", id, j)
|
||||
data := fmt.Sprintf("writer-%d-%d", id, j)
|
||||
s.fs.Put(ctx, key, strings.NewReader(data), nil)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify quota tracking is consistent
|
||||
quotaInfo, err := s.fs.GetQuota(ctx)
|
||||
s.NoError(err)
|
||||
s.Greater(quotaInfo.Used, int64(0))
|
||||
}
|
||||
|
||||
// Test Delete updates quota correctly
|
||||
func (s *FilesystemStorageTestSuite) TestDeleteUpdatesQuota() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Put a file
|
||||
testData := "test data for quota tracking"
|
||||
err := s.fs.Put(ctx, "quota/test.txt", strings.NewReader(testData), nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Get quota before delete
|
||||
quotaBefore, err := s.fs.GetQuota(ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Delete the file
|
||||
err = s.fs.Delete(ctx, "quota/test.txt")
|
||||
s.NoError(err)
|
||||
|
||||
// Get quota after delete
|
||||
quotaAfter, err := s.fs.GetQuota(ctx)
|
||||
s.NoError(err)
|
||||
|
||||
// Quota should have decreased
|
||||
s.Less(quotaAfter.Used, quotaBefore.Used)
|
||||
}
|
||||
|
||||
// Test atomic write behavior
|
||||
func (s *FilesystemStorageTestSuite) TestAtomicWrite() {
|
||||
ctx := context.Background()
|
||||
key := "atomic/test.txt"
|
||||
|
||||
// Initial write
|
||||
err := s.fs.Put(ctx, key, strings.NewReader("initial"), nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Concurrent readers should never see partial writes
|
||||
var wg sync.WaitGroup
|
||||
stopReading := make(chan struct{})
|
||||
readErrors := make(chan error, 100)
|
||||
|
||||
// Start readers
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-stopReading:
|
||||
return
|
||||
default:
|
||||
reader, err := s.fs.Get(ctx, key)
|
||||
if err != nil {
|
||||
readErrors <- err
|
||||
continue
|
||||
}
|
||||
data, err := io.ReadAll(reader)
|
||||
reader.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
if err != nil {
|
||||
readErrors <- err
|
||||
continue
|
||||
}
|
||||
// Data should be either "initial" or "updated", never partial
|
||||
content := string(data)
|
||||
if content != "initial" && content != "updated" {
|
||||
readErrors <- fmt.Errorf("read partial data: %s", content)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Perform update
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
err = s.fs.Put(ctx, key, strings.NewReader("updated"), nil)
|
||||
s.NoError(err)
|
||||
|
||||
// Stop readers
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
close(stopReading)
|
||||
wg.Wait()
|
||||
close(readErrors)
|
||||
|
||||
// Check for read errors
|
||||
for err := range readErrors {
|
||||
s.NoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test path sanitization
|
||||
func (s *FilesystemStorageTestSuite) TestPathSanitization() {
|
||||
ctx := context.Background()
|
||||
|
||||
maliciousPaths := []string{
|
||||
"../../../etc/passwd",
|
||||
"/../secret.txt",
|
||||
"./../../outside.txt",
|
||||
"//etc/passwd",
|
||||
}
|
||||
|
||||
for _, path := range maliciousPaths {
|
||||
s.Run(fmt.Sprintf("sanitize_%s", path), func() {
|
||||
err := s.fs.Put(ctx, path, strings.NewReader("malicious"), nil)
|
||||
s.NoError(err) // Should succeed but sanitize path
|
||||
|
||||
// Verify file is inside base directory
|
||||
sanitized := s.fs.keyToPath(path)
|
||||
s.True(strings.HasPrefix(sanitized, s.tempDir),
|
||||
"Sanitized path %s should be inside %s", sanitized, s.tempDir)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test checksum validation
|
||||
func (s *FilesystemStorageTestSuite) TestChecksumValidation() {
|
||||
ctx := context.Background()
|
||||
|
||||
testData := "checksum test data"
|
||||
// Correct checksums calculated for "checksum test data"
|
||||
correctMD5 := "7dd7323e8ce3e087972f93d3711ef62b"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts *storage.PutOptions
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid MD5",
|
||||
opts: &storage.PutOptions{ChecksumMD5: correctMD5},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid MD5",
|
||||
opts: &storage.PutOptions{ChecksumMD5: "invalid"},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "empty checksum (no validation)",
|
||||
opts: &storage.PutOptions{ChecksumMD5: ""},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
key := fmt.Sprintf("checksum/%s.txt", tt.name)
|
||||
err := s.fs.Put(ctx, key, strings.NewReader(testData), tt.opts)
|
||||
|
||||
if tt.expectError {
|
||||
s.Error(err)
|
||||
} else {
|
||||
s.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark Put operation
|
||||
func BenchmarkFilesystemPut(b *testing.B) {
|
||||
tempDir, _ := os.MkdirTemp("", "gohoarder-bench-*")
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
fs, _ := New(tempDir, 1024*1024*1024) // 1GB quota
|
||||
defer fs.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
ctx := context.Background()
|
||||
data := strings.Repeat("x", 1024) // 1KB
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := fmt.Sprintf("bench/file-%d.txt", i)
|
||||
fs.Put(ctx, key, strings.NewReader(data), nil)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark Get operation
|
||||
func BenchmarkFilesystemGet(b *testing.B) {
|
||||
tempDir, _ := os.MkdirTemp("", "gohoarder-bench-*")
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
fs, _ := New(tempDir, 1024*1024*1024)
|
||||
defer fs.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
ctx := context.Background()
|
||||
data := strings.Repeat("x", 1024)
|
||||
|
||||
// Setup: Create test file
|
||||
fs.Put(ctx, "bench/test.txt", strings.NewReader(data), nil)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
reader, _ := fs.Get(ctx, "bench/test.txt")
|
||||
if reader != nil {
|
||||
io.ReadAll(reader)
|
||||
reader.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
// StorageBackend defines the interface for package storage
|
||||
type StorageBackend interface {
|
||||
// Get retrieves a package by key
|
||||
Get(ctx context.Context, key string) (io.ReadCloser, error)
|
||||
|
||||
// Put stores a package
|
||||
Put(ctx context.Context, key string, data io.Reader, opts *PutOptions) error
|
||||
|
||||
// Delete removes a package
|
||||
Delete(ctx context.Context, key string) error
|
||||
|
||||
// Exists checks if a package exists
|
||||
Exists(ctx context.Context, key string) (bool, error)
|
||||
|
||||
// List lists packages with prefix
|
||||
List(ctx context.Context, prefix string, opts *ListOptions) ([]StorageObject, error)
|
||||
|
||||
// Stat gets package metadata
|
||||
Stat(ctx context.Context, key string) (*StorageInfo, error)
|
||||
|
||||
// GetQuota returns quota information
|
||||
GetQuota(ctx context.Context) (*QuotaInfo, error)
|
||||
|
||||
// Health checks backend health
|
||||
Health(ctx context.Context) error
|
||||
|
||||
// Close closes the backend
|
||||
Close() error
|
||||
}
|
||||
|
||||
// PutOptions contains options for Put operations
|
||||
type PutOptions struct {
|
||||
ContentType string
|
||||
Metadata map[string]string
|
||||
ChecksumMD5 string
|
||||
ChecksumSHA256 string
|
||||
}
|
||||
|
||||
// ListOptions contains options for List operations
|
||||
type ListOptions struct {
|
||||
MaxResults int
|
||||
Offset int
|
||||
}
|
||||
|
||||
// StorageObject represents a stored object
|
||||
type StorageObject struct {
|
||||
Key string
|
||||
Size int64
|
||||
Modified time.Time
|
||||
ETag string
|
||||
}
|
||||
|
||||
// StorageInfo contains detailed object information
|
||||
type StorageInfo struct {
|
||||
Key string
|
||||
Size int64
|
||||
Modified time.Time
|
||||
ETag string
|
||||
ContentType string
|
||||
Metadata map[string]string
|
||||
Checksums *Checksums
|
||||
}
|
||||
|
||||
// Checksums contains file checksums
|
||||
type Checksums struct {
|
||||
MD5 string
|
||||
SHA256 string
|
||||
}
|
||||
|
||||
// QuotaInfo contains quota information
|
||||
type QuotaInfo struct {
|
||||
Used int64
|
||||
Available int64
|
||||
Limit int64
|
||||
}
|
||||
|
||||
// LocalPathProvider is an optional interface that storage backends can implement
|
||||
// to provide direct file system paths for scanning without creating temp copies
|
||||
type LocalPathProvider interface {
|
||||
// GetLocalPath returns the local filesystem path for a storage key
|
||||
// Returns empty string if the backend doesn't support local paths (e.g., S3, SMB)
|
||||
GetLocalPath(ctx context.Context, key string) (string, error)
|
||||
}
|
||||
@@ -0,0 +1,443 @@
|
||||
package s3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/md5" // #nosec G501 -- MD5 used for S3 Content-MD5 header, not cryptographic security
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
stderrors "errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3/types"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/errors"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metrics"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/storage"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// S3Storage implements storage.StorageBackend for AWS S3
|
||||
type S3Storage struct {
|
||||
client *s3.Client
|
||||
bucket string
|
||||
prefix string
|
||||
quota int64
|
||||
mu sync.RWMutex
|
||||
used int64
|
||||
}
|
||||
|
||||
// Config holds S3 configuration
|
||||
type Config struct {
|
||||
Bucket string
|
||||
Region string
|
||||
Endpoint string // For S3-compatible services (MinIO, etc.)
|
||||
AccessKeyID string
|
||||
SecretAccessKey string
|
||||
Prefix string // Optional prefix for all keys
|
||||
Quota int64 // Quota in bytes (0 = unlimited)
|
||||
ForcePathStyle bool // For S3-compatible services
|
||||
}
|
||||
|
||||
// New creates a new S3 storage backend
|
||||
func New(ctx context.Context, cfg Config) (*S3Storage, error) {
|
||||
if cfg.Bucket == "" {
|
||||
return nil, errors.New(errors.ErrCodeInvalidConfig, "S3 bucket is required")
|
||||
}
|
||||
|
||||
if cfg.Region == "" {
|
||||
return nil, errors.New(errors.ErrCodeInvalidConfig, "S3 region is required")
|
||||
}
|
||||
|
||||
// Build AWS config
|
||||
var awsCfg aws.Config
|
||||
var err error
|
||||
|
||||
if cfg.AccessKeyID != "" && cfg.SecretAccessKey != "" {
|
||||
// Use static credentials
|
||||
awsCfg, err = config.LoadDefaultConfig(ctx,
|
||||
config.WithRegion(cfg.Region),
|
||||
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
|
||||
cfg.AccessKeyID,
|
||||
cfg.SecretAccessKey,
|
||||
"",
|
||||
)),
|
||||
)
|
||||
} else {
|
||||
// Use default credential chain
|
||||
awsCfg, err = config.LoadDefaultConfig(ctx,
|
||||
config.WithRegion(cfg.Region),
|
||||
)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to load AWS config")
|
||||
}
|
||||
|
||||
// Create S3 client
|
||||
var s3Options []func(*s3.Options)
|
||||
|
||||
if cfg.Endpoint != "" {
|
||||
s3Options = append(s3Options, func(o *s3.Options) {
|
||||
o.BaseEndpoint = aws.String(cfg.Endpoint)
|
||||
o.UsePathStyle = cfg.ForcePathStyle
|
||||
})
|
||||
}
|
||||
|
||||
client := s3.NewFromConfig(awsCfg, s3Options...)
|
||||
|
||||
s3Storage := &S3Storage{
|
||||
client: client,
|
||||
bucket: cfg.Bucket,
|
||||
prefix: strings.TrimSuffix(cfg.Prefix, "/"),
|
||||
quota: cfg.Quota,
|
||||
}
|
||||
|
||||
// Calculate initial usage
|
||||
if err := s3Storage.calculateUsage(ctx); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to calculate initial S3 storage usage")
|
||||
}
|
||||
|
||||
return s3Storage, nil
|
||||
}
|
||||
|
||||
// Get retrieves a file from S3
|
||||
func (s *S3Storage) Get(ctx context.Context, key string) (io.ReadCloser, error) {
|
||||
s3Key := s.buildKey(key)
|
||||
|
||||
input := &s3.GetObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(s3Key),
|
||||
}
|
||||
|
||||
result, err := s.client.GetObject(ctx, input)
|
||||
if err != nil {
|
||||
if isNotFoundError(err) {
|
||||
metrics.RecordStorageOperation("s3", "get", "not_found")
|
||||
return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key))
|
||||
}
|
||||
metrics.RecordStorageOperation("s3", "get", "error")
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get object from S3")
|
||||
}
|
||||
|
||||
metrics.RecordStorageOperation("s3", "get", "success")
|
||||
return result.Body, nil
|
||||
}
|
||||
|
||||
// Put stores a file in S3
|
||||
func (s *S3Storage) Put(ctx context.Context, key string, data io.Reader, opts *storage.PutOptions) error {
|
||||
s3Key := s.buildKey(key)
|
||||
|
||||
// Read data into buffer to calculate checksums and size
|
||||
var buf bytes.Buffer
|
||||
md5Hash := md5.New() // #nosec G401 -- MD5 used for S3 integrity check, not cryptographic security
|
||||
sha256Hash := sha256.New()
|
||||
multiWriter := io.MultiWriter(&buf, md5Hash, sha256Hash)
|
||||
|
||||
written, err := io.Copy(multiWriter, data)
|
||||
if err != nil {
|
||||
metrics.RecordStorageOperation("s3", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to read data")
|
||||
}
|
||||
|
||||
// Check quota before upload
|
||||
if s.quota > 0 {
|
||||
s.mu.RLock()
|
||||
used := s.used
|
||||
s.mu.RUnlock()
|
||||
|
||||
if used+written > s.quota {
|
||||
metrics.RecordStorageOperation("s3", "put", "quota_exceeded")
|
||||
return errors.QuotaExceeded(s.quota)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify checksums if provided
|
||||
if opts != nil {
|
||||
md5Sum := hex.EncodeToString(md5Hash.Sum(nil))
|
||||
sha256Sum := hex.EncodeToString(sha256Hash.Sum(nil))
|
||||
|
||||
if opts.ChecksumMD5 != "" && opts.ChecksumMD5 != md5Sum {
|
||||
metrics.RecordStorageOperation("s3", "put", "checksum_error")
|
||||
return errors.New(errors.ErrCodeChecksumMismatch, "MD5 checksum mismatch")
|
||||
}
|
||||
|
||||
if opts.ChecksumSHA256 != "" && opts.ChecksumSHA256 != sha256Sum {
|
||||
metrics.RecordStorageOperation("s3", "put", "checksum_error")
|
||||
return errors.New(errors.ErrCodeChecksumMismatch, "SHA256 checksum mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare metadata
|
||||
metadata := make(map[string]string)
|
||||
if opts != nil && opts.Metadata != nil {
|
||||
metadata = opts.Metadata
|
||||
}
|
||||
|
||||
// Build put input
|
||||
input := &s3.PutObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(s3Key),
|
||||
Body: bytes.NewReader(buf.Bytes()),
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
if opts != nil && opts.ContentType != "" {
|
||||
input.ContentType = aws.String(opts.ContentType)
|
||||
}
|
||||
|
||||
// Upload to S3
|
||||
_, err = s.client.PutObject(ctx, input)
|
||||
if err != nil {
|
||||
metrics.RecordStorageOperation("s3", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to upload to S3")
|
||||
}
|
||||
|
||||
// Update usage
|
||||
s.mu.Lock()
|
||||
s.used += written
|
||||
currentUsed := s.used
|
||||
s.mu.Unlock()
|
||||
|
||||
metrics.RecordStorageOperation("s3", "put", "success")
|
||||
metrics.UpdateCacheSize("s3", currentUsed)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a file from S3
|
||||
func (s *S3Storage) Delete(ctx context.Context, key string) error {
|
||||
s3Key := s.buildKey(key)
|
||||
|
||||
// Get size before deletion for quota tracking
|
||||
statInfo, err := s.Stat(ctx, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
input := &s3.DeleteObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(s3Key),
|
||||
}
|
||||
|
||||
_, err = s.client.DeleteObject(ctx, input)
|
||||
if err != nil {
|
||||
metrics.RecordStorageOperation("s3", "delete", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to delete from S3")
|
||||
}
|
||||
|
||||
// Update usage
|
||||
s.mu.Lock()
|
||||
s.used -= statInfo.Size
|
||||
currentUsed := s.used
|
||||
s.mu.Unlock()
|
||||
|
||||
metrics.RecordStorageOperation("s3", "delete", "success")
|
||||
metrics.UpdateCacheSize("s3", currentUsed)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists checks if a file exists in S3
|
||||
func (s *S3Storage) Exists(ctx context.Context, key string) (bool, error) {
|
||||
s3Key := s.buildKey(key)
|
||||
|
||||
input := &s3.HeadObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(s3Key),
|
||||
}
|
||||
|
||||
_, err := s.client.HeadObject(ctx, input)
|
||||
if err != nil {
|
||||
if isNotFoundError(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to check existence in S3")
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// List lists files with prefix in S3
|
||||
func (s *S3Storage) List(ctx context.Context, prefix string, opts *storage.ListOptions) ([]storage.StorageObject, error) {
|
||||
s3Prefix := s.buildKey(prefix)
|
||||
|
||||
input := &s3.ListObjectsV2Input{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Prefix: aws.String(s3Prefix),
|
||||
}
|
||||
|
||||
var objects []storage.StorageObject
|
||||
paginator := s3.NewListObjectsV2Paginator(s.client, input)
|
||||
|
||||
for paginator.HasMorePages() {
|
||||
page, err := paginator.NextPage(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to list objects in S3")
|
||||
}
|
||||
|
||||
for _, obj := range page.Contents {
|
||||
key := s.stripPrefix(*obj.Key)
|
||||
objects = append(objects, storage.StorageObject{
|
||||
Key: key,
|
||||
Size: *obj.Size,
|
||||
Modified: *obj.LastModified,
|
||||
ETag: strings.Trim(*obj.ETag, "\""),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Apply pagination if requested
|
||||
if opts != nil {
|
||||
start := opts.Offset
|
||||
end := len(objects)
|
||||
if opts.MaxResults > 0 && start+opts.MaxResults < end {
|
||||
end = start + opts.MaxResults
|
||||
}
|
||||
if start < len(objects) {
|
||||
objects = objects[start:end]
|
||||
} else {
|
||||
objects = []storage.StorageObject{}
|
||||
}
|
||||
}
|
||||
|
||||
return objects, nil
|
||||
}
|
||||
|
||||
// Stat gets file metadata from S3
|
||||
func (s *S3Storage) Stat(ctx context.Context, key string) (*storage.StorageInfo, error) {
|
||||
s3Key := s.buildKey(key)
|
||||
|
||||
input := &s3.HeadObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(s3Key),
|
||||
}
|
||||
|
||||
result, err := s.client.HeadObject(ctx, input)
|
||||
if err != nil {
|
||||
if isNotFoundError(err) {
|
||||
return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key))
|
||||
}
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to stat object in S3")
|
||||
}
|
||||
|
||||
info := &storage.StorageInfo{
|
||||
Key: key,
|
||||
Size: *result.ContentLength,
|
||||
Modified: *result.LastModified,
|
||||
ETag: strings.Trim(*result.ETag, "\""),
|
||||
Metadata: result.Metadata,
|
||||
}
|
||||
|
||||
if result.ContentType != nil {
|
||||
info.ContentType = *result.ContentType
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// GetQuota returns quota information
|
||||
func (s *S3Storage) GetQuota(ctx context.Context) (*storage.QuotaInfo, error) {
|
||||
s.mu.RLock()
|
||||
used := s.used
|
||||
s.mu.RUnlock()
|
||||
|
||||
available := s.quota - used
|
||||
if available < 0 {
|
||||
available = 0
|
||||
}
|
||||
|
||||
return &storage.QuotaInfo{
|
||||
Used: used,
|
||||
Available: available,
|
||||
Limit: s.quota,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Health checks S3 health
|
||||
func (s *S3Storage) Health(ctx context.Context) error {
|
||||
// Try to list bucket to verify connectivity
|
||||
input := &s3.ListObjectsV2Input{
|
||||
Bucket: aws.String(s.bucket),
|
||||
MaxKeys: aws.Int32(1),
|
||||
}
|
||||
|
||||
_, err := s.client.ListObjectsV2(ctx, input)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "S3 health check failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the storage backend
|
||||
func (s *S3Storage) Close() error {
|
||||
// No cleanup needed for S3 client
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildKey builds the full S3 key with prefix
|
||||
func (s *S3Storage) buildKey(key string) string {
|
||||
key = strings.TrimPrefix(key, "/")
|
||||
if s.prefix != "" {
|
||||
return s.prefix + "/" + key
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
// stripPrefix removes the configured prefix from an S3 key
|
||||
func (s *S3Storage) stripPrefix(s3Key string) string {
|
||||
if s.prefix != "" {
|
||||
return strings.TrimPrefix(s3Key, s.prefix+"/")
|
||||
}
|
||||
return s3Key
|
||||
}
|
||||
|
||||
// calculateUsage calculates current S3 storage usage
|
||||
func (s *S3Storage) calculateUsage(ctx context.Context) error {
|
||||
var total int64
|
||||
|
||||
input := &s3.ListObjectsV2Input{
|
||||
Bucket: aws.String(s.bucket),
|
||||
}
|
||||
|
||||
if s.prefix != "" {
|
||||
input.Prefix = aws.String(s.prefix + "/")
|
||||
}
|
||||
|
||||
paginator := s3.NewListObjectsV2Paginator(s.client, input)
|
||||
|
||||
for paginator.HasMorePages() {
|
||||
page, err := paginator.NextPage(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, obj := range page.Contents {
|
||||
total += *obj.Size
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.used = total
|
||||
s.mu.Unlock()
|
||||
|
||||
metrics.UpdateCacheSize("s3", total)
|
||||
return nil
|
||||
}
|
||||
|
||||
// isNotFoundError checks if an error is a "not found" error
|
||||
func isNotFoundError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var notFound *types.NotFound
|
||||
var noSuchKey *types.NoSuchKey
|
||||
|
||||
return stderrors.As(err, ¬Found) || stderrors.As(err, &noSuchKey)
|
||||
}
|
||||
@@ -0,0 +1,579 @@
|
||||
package smb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/md5" // #nosec G501 -- MD5 used for file checksums, not cryptographic security
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hirochachacha/go-smb2"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/errors"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metrics"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/storage"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// SMBStorage implements storage.StorageBackend for SMB/CIFS shares
|
||||
type SMBStorage struct {
|
||||
host string
|
||||
share string
|
||||
basePath string
|
||||
username string
|
||||
password string
|
||||
quota int64
|
||||
mu sync.RWMutex
|
||||
used int64
|
||||
connPool chan *smbConnection
|
||||
poolSize int
|
||||
}
|
||||
|
||||
// smbConnection wraps an SMB session and share
|
||||
type smbConnection struct {
|
||||
conn net.Conn
|
||||
session *smb2.Session
|
||||
share *smb2.Share
|
||||
lastUse time.Time
|
||||
}
|
||||
|
||||
// Config holds SMB configuration
|
||||
type Config struct {
|
||||
Host string // SMB server hostname or IP
|
||||
Port int // SMB server port (default: 445)
|
||||
Share string // SMB share name
|
||||
BasePath string // Base path within the share
|
||||
Username string // SMB username
|
||||
Password string // SMB password
|
||||
Domain string // SMB domain (optional)
|
||||
Quota int64 // Quota in bytes (0 = unlimited)
|
||||
PoolSize int // Connection pool size (default: 5)
|
||||
}
|
||||
|
||||
// New creates a new SMB storage backend
|
||||
func New(ctx context.Context, cfg Config) (*SMBStorage, error) {
|
||||
if cfg.Host == "" {
|
||||
return nil, errors.New(errors.ErrCodeInvalidConfig, "SMB host is required")
|
||||
}
|
||||
|
||||
if cfg.Share == "" {
|
||||
return nil, errors.New(errors.ErrCodeInvalidConfig, "SMB share is required")
|
||||
}
|
||||
|
||||
if cfg.Port == 0 {
|
||||
cfg.Port = 445 // Default SMB port
|
||||
}
|
||||
|
||||
if cfg.PoolSize == 0 {
|
||||
cfg.PoolSize = 5 // Default pool size
|
||||
}
|
||||
|
||||
smbStorage := &SMBStorage{
|
||||
host: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
||||
share: cfg.Share,
|
||||
basePath: strings.Trim(cfg.BasePath, "/\\"),
|
||||
username: cfg.Username,
|
||||
password: cfg.Password,
|
||||
quota: cfg.Quota,
|
||||
connPool: make(chan *smbConnection, cfg.PoolSize),
|
||||
poolSize: cfg.PoolSize,
|
||||
}
|
||||
|
||||
// Initialize connection pool
|
||||
for i := 0; i < cfg.PoolSize; i++ {
|
||||
conn, err := smbStorage.createConnection(ctx)
|
||||
if err != nil {
|
||||
// Clean up any created connections
|
||||
close(smbStorage.connPool)
|
||||
for c := range smbStorage.connPool {
|
||||
c.close()
|
||||
}
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create SMB connection pool")
|
||||
}
|
||||
smbStorage.connPool <- conn
|
||||
}
|
||||
|
||||
// Calculate initial usage
|
||||
if err := smbStorage.calculateUsage(ctx); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to calculate initial SMB storage usage")
|
||||
}
|
||||
|
||||
return smbStorage, nil
|
||||
}
|
||||
|
||||
// createConnection creates a new SMB connection
|
||||
func (s *SMBStorage) createConnection(ctx context.Context) (*smbConnection, error) {
|
||||
conn, err := net.Dial("tcp", s.host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dialer := &smb2.Dialer{
|
||||
Initiator: &smb2.NTLMInitiator{
|
||||
User: s.username,
|
||||
Password: s.password,
|
||||
},
|
||||
}
|
||||
|
||||
session, err := dialer.Dial(conn)
|
||||
if err != nil {
|
||||
conn.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
return nil, err
|
||||
}
|
||||
|
||||
share, err := session.Mount(s.share)
|
||||
if err != nil {
|
||||
_ = session.Logoff() // #nosec G104 -- SMB cleanup
|
||||
conn.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &smbConnection{
|
||||
conn: conn,
|
||||
session: session,
|
||||
share: share,
|
||||
lastUse: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getConnection gets a connection from the pool
|
||||
func (s *SMBStorage) getConnection(ctx context.Context) (*smbConnection, error) {
|
||||
select {
|
||||
case conn := <-s.connPool:
|
||||
conn.lastUse = time.Now()
|
||||
return conn, nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(30 * time.Second):
|
||||
return nil, errors.New(errors.ErrCodeStorageFailure, "timeout waiting for SMB connection")
|
||||
}
|
||||
}
|
||||
|
||||
// returnConnection returns a connection to the pool
|
||||
func (s *SMBStorage) returnConnection(conn *smbConnection) {
|
||||
select {
|
||||
case s.connPool <- conn:
|
||||
default:
|
||||
// Pool is full, close the connection
|
||||
conn.close()
|
||||
}
|
||||
}
|
||||
|
||||
// close closes an SMB connection
|
||||
func (c *smbConnection) close() {
|
||||
if c.share != nil {
|
||||
_ = c.share.Umount() // #nosec G104 -- SMB cleanup
|
||||
}
|
||||
if c.session != nil {
|
||||
_ = c.session.Logoff() // #nosec G104 -- SMB cleanup
|
||||
}
|
||||
if c.conn != nil {
|
||||
c.conn.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a file from SMB share
|
||||
func (s *SMBStorage) Get(ctx context.Context, key string) (io.ReadCloser, error) {
|
||||
conn, err := s.getConnection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
path := s.keyToPath(key)
|
||||
|
||||
// Open file
|
||||
file, err := conn.share.Open(path)
|
||||
if err != nil {
|
||||
s.returnConnection(conn)
|
||||
if os.IsNotExist(err) {
|
||||
metrics.RecordStorageOperation("smb", "get", "not_found")
|
||||
return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key))
|
||||
}
|
||||
metrics.RecordStorageOperation("smb", "get", "error")
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to open SMB file")
|
||||
}
|
||||
|
||||
// Read entire file into memory and close SMB connection
|
||||
// This is necessary because we need to return the connection to the pool
|
||||
data, err := io.ReadAll(file)
|
||||
file.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
s.returnConnection(conn)
|
||||
|
||||
if err != nil {
|
||||
metrics.RecordStorageOperation("smb", "get", "error")
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to read SMB file")
|
||||
}
|
||||
|
||||
metrics.RecordStorageOperation("smb", "get", "success")
|
||||
return io.NopCloser(bytes.NewReader(data)), nil
|
||||
}
|
||||
|
||||
// Put stores a file on SMB share
|
||||
func (s *SMBStorage) Put(ctx context.Context, key string, data io.Reader, opts *storage.PutOptions) error {
|
||||
conn, err := s.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.returnConnection(conn)
|
||||
|
||||
path := s.keyToPath(key)
|
||||
dir := filepath.Dir(path)
|
||||
|
||||
// Create directory structure
|
||||
if err := conn.share.MkdirAll(dir, 0755); err != nil {
|
||||
metrics.RecordStorageOperation("smb", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create SMB directory")
|
||||
}
|
||||
|
||||
// Read data into buffer to calculate checksums and size
|
||||
var buf bytes.Buffer
|
||||
md5Hash := md5.New() // #nosec G401 -- MD5 used for file integrity check, not cryptographic security
|
||||
sha256Hash := sha256.New()
|
||||
multiWriter := io.MultiWriter(&buf, md5Hash, sha256Hash)
|
||||
|
||||
written, err := io.Copy(multiWriter, data)
|
||||
if err != nil {
|
||||
metrics.RecordStorageOperation("smb", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to read data")
|
||||
}
|
||||
|
||||
// Check quota
|
||||
if s.quota > 0 {
|
||||
s.mu.RLock()
|
||||
used := s.used
|
||||
s.mu.RUnlock()
|
||||
|
||||
if used+written > s.quota {
|
||||
metrics.RecordStorageOperation("smb", "put", "quota_exceeded")
|
||||
return errors.QuotaExceeded(s.quota)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify checksums if provided
|
||||
if opts != nil {
|
||||
md5Sum := hex.EncodeToString(md5Hash.Sum(nil))
|
||||
sha256Sum := hex.EncodeToString(sha256Hash.Sum(nil))
|
||||
|
||||
if opts.ChecksumMD5 != "" && opts.ChecksumMD5 != md5Sum {
|
||||
metrics.RecordStorageOperation("smb", "put", "checksum_error")
|
||||
return errors.New(errors.ErrCodeChecksumMismatch, "MD5 checksum mismatch")
|
||||
}
|
||||
|
||||
if opts.ChecksumSHA256 != "" && opts.ChecksumSHA256 != sha256Sum {
|
||||
metrics.RecordStorageOperation("smb", "put", "checksum_error")
|
||||
return errors.New(errors.ErrCodeChecksumMismatch, "SHA256 checksum mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
// Create temp file for atomic write
|
||||
tempPath := path + ".tmp"
|
||||
file, err := conn.share.Create(tempPath)
|
||||
if err != nil {
|
||||
metrics.RecordStorageOperation("smb", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create SMB temp file")
|
||||
}
|
||||
|
||||
// Write data
|
||||
_, err = io.Copy(file, bytes.NewReader(buf.Bytes()))
|
||||
file.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
if err != nil {
|
||||
_ = conn.share.Remove(tempPath) // #nosec G104 -- SMB cleanup
|
||||
metrics.RecordStorageOperation("smb", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to write SMB file")
|
||||
}
|
||||
|
||||
// Atomic rename
|
||||
if err := conn.share.Rename(tempPath, path); err != nil {
|
||||
_ = conn.share.Remove(tempPath) // #nosec G104 -- SMB cleanup
|
||||
metrics.RecordStorageOperation("smb", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to rename SMB temp file")
|
||||
}
|
||||
|
||||
// Update usage
|
||||
s.mu.Lock()
|
||||
s.used += written
|
||||
currentUsed := s.used
|
||||
s.mu.Unlock()
|
||||
|
||||
metrics.RecordStorageOperation("smb", "put", "success")
|
||||
metrics.UpdateCacheSize("smb", currentUsed)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a file from SMB share
|
||||
func (s *SMBStorage) Delete(ctx context.Context, key string) error {
|
||||
conn, err := s.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.returnConnection(conn)
|
||||
|
||||
path := s.keyToPath(key)
|
||||
|
||||
// Get size before deletion
|
||||
info, err := conn.share.Stat(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
metrics.RecordStorageOperation("smb", "delete", "not_found")
|
||||
return errors.NotFound(fmt.Sprintf("file not found: %s", key))
|
||||
}
|
||||
metrics.RecordStorageOperation("smb", "delete", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to stat SMB file")
|
||||
}
|
||||
|
||||
size := info.Size()
|
||||
|
||||
if err := conn.share.Remove(path); err != nil {
|
||||
metrics.RecordStorageOperation("smb", "delete", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to delete SMB file")
|
||||
}
|
||||
|
||||
// Update usage
|
||||
s.mu.Lock()
|
||||
s.used -= size
|
||||
currentUsed := s.used
|
||||
s.mu.Unlock()
|
||||
|
||||
metrics.RecordStorageOperation("smb", "delete", "success")
|
||||
metrics.UpdateCacheSize("smb", currentUsed)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists checks if a file exists on SMB share
|
||||
func (s *SMBStorage) Exists(ctx context.Context, key string) (bool, error) {
|
||||
conn, err := s.getConnection(ctx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer s.returnConnection(conn)
|
||||
|
||||
path := s.keyToPath(key)
|
||||
|
||||
_, err = conn.share.Stat(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to check SMB file existence")
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// List lists files with prefix on SMB share
|
||||
func (s *SMBStorage) List(ctx context.Context, prefix string, opts *storage.ListOptions) ([]storage.StorageObject, error) {
|
||||
conn, err := s.getConnection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer s.returnConnection(conn)
|
||||
|
||||
searchPath := s.keyToPath(prefix)
|
||||
var objects []storage.StorageObject
|
||||
|
||||
err = s.walkPath(conn.share, searchPath, func(path string, info os.FileInfo) error {
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := s.pathToKey(path)
|
||||
objects = append(objects, storage.StorageObject{
|
||||
Key: key,
|
||||
Size: info.Size(),
|
||||
Modified: info.ModTime(),
|
||||
})
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to list SMB files")
|
||||
}
|
||||
|
||||
// Apply pagination if requested
|
||||
if opts != nil {
|
||||
start := opts.Offset
|
||||
end := len(objects)
|
||||
if opts.MaxResults > 0 && start+opts.MaxResults < end {
|
||||
end = start + opts.MaxResults
|
||||
}
|
||||
if start < len(objects) {
|
||||
objects = objects[start:end]
|
||||
} else {
|
||||
objects = []storage.StorageObject{}
|
||||
}
|
||||
}
|
||||
|
||||
return objects, nil
|
||||
}
|
||||
|
||||
// Stat gets file metadata from SMB share
|
||||
func (s *SMBStorage) Stat(ctx context.Context, key string) (*storage.StorageInfo, error) {
|
||||
conn, err := s.getConnection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer s.returnConnection(conn)
|
||||
|
||||
path := s.keyToPath(key)
|
||||
|
||||
info, err := conn.share.Stat(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key))
|
||||
}
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to stat SMB file")
|
||||
}
|
||||
|
||||
return &storage.StorageInfo{
|
||||
Key: key,
|
||||
Size: info.Size(),
|
||||
Modified: info.ModTime(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetQuota returns quota information
|
||||
func (s *SMBStorage) GetQuota(ctx context.Context) (*storage.QuotaInfo, error) {
|
||||
s.mu.RLock()
|
||||
used := s.used
|
||||
s.mu.RUnlock()
|
||||
|
||||
available := s.quota - used
|
||||
if available < 0 {
|
||||
available = 0
|
||||
}
|
||||
|
||||
return &storage.QuotaInfo{
|
||||
Used: used,
|
||||
Available: available,
|
||||
Limit: s.quota,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Health checks SMB health
|
||||
func (s *SMBStorage) Health(ctx context.Context) error {
|
||||
conn, err := s.getConnection(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "SMB health check failed - connection error")
|
||||
}
|
||||
defer s.returnConnection(conn)
|
||||
|
||||
// Try to stat the base path
|
||||
path := s.keyToPath("")
|
||||
_, err = conn.share.Stat(path)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "SMB health check failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the storage backend
|
||||
func (s *SMBStorage) Close() error {
|
||||
close(s.connPool)
|
||||
for conn := range s.connPool {
|
||||
conn.close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// keyToPath converts a storage key to SMB path
|
||||
func (s *SMBStorage) keyToPath(key string) string {
|
||||
key = strings.TrimPrefix(key, "/")
|
||||
key = filepath.Clean(key)
|
||||
|
||||
// Remove path traversal attempts
|
||||
for strings.HasPrefix(key, "../") || strings.HasPrefix(key, "..\\") {
|
||||
key = strings.TrimPrefix(key, "../")
|
||||
key = strings.TrimPrefix(key, "..\\")
|
||||
}
|
||||
|
||||
key = filepath.Clean(key)
|
||||
if key == ".." || strings.HasPrefix(key, "../") || strings.HasPrefix(key, "..\\") {
|
||||
key = ""
|
||||
}
|
||||
|
||||
if s.basePath != "" {
|
||||
return filepath.Join(s.basePath, key)
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
// pathToKey converts an SMB path back to a storage key
|
||||
func (s *SMBStorage) pathToKey(path string) string {
|
||||
if s.basePath != "" {
|
||||
path = strings.TrimPrefix(path, s.basePath)
|
||||
path = strings.TrimPrefix(path, "/")
|
||||
path = strings.TrimPrefix(path, "\\")
|
||||
}
|
||||
return filepath.ToSlash(path)
|
||||
}
|
||||
|
||||
// walkPath recursively walks an SMB directory
|
||||
func (s *SMBStorage) walkPath(share *smb2.Share, path string, fn func(string, os.FileInfo) error) error {
|
||||
info, err := share.Stat(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !info.IsDir() {
|
||||
return fn(path, info)
|
||||
}
|
||||
|
||||
entries, err := share.ReadDir(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
entryPath := filepath.Join(path, entry.Name())
|
||||
if entry.IsDir() {
|
||||
if err := s.walkPath(share, entryPath, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := fn(entryPath, entry); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateUsage calculates current SMB storage usage
|
||||
func (s *SMBStorage) calculateUsage(ctx context.Context) error {
|
||||
conn, err := s.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.returnConnection(conn)
|
||||
|
||||
var total int64
|
||||
basePath := s.keyToPath("")
|
||||
|
||||
err = s.walkPath(conn.share, basePath, func(path string, info os.FileInfo) error {
|
||||
if !info.IsDir() {
|
||||
total += info.Size()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.used = total
|
||||
s.mu.Unlock()
|
||||
|
||||
metrics.UpdateCacheSize("smb", total)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package uuid
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// UUID represents a UUID (RFC 4122)
|
||||
type UUID [16]byte
|
||||
|
||||
// New generates a random UUID v4
|
||||
func New() UUID {
|
||||
var u UUID
|
||||
// Read random bytes
|
||||
if _, err := rand.Read(u[:]); err != nil {
|
||||
panic(fmt.Sprintf("failed to generate UUID: %v", err))
|
||||
}
|
||||
|
||||
// Set version (4) and variant (RFC 4122)
|
||||
u[6] = (u[6] & 0x0f) | 0x40 // Version 4
|
||||
u[8] = (u[8] & 0x3f) | 0x80 // Variant RFC 4122
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
// String returns the UUID in standard format (xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx)
|
||||
func (u UUID) String() string {
|
||||
return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x",
|
||||
u[0:4], u[4:6], u[6:8], u[8:10], u[10:16])
|
||||
}
|
||||
@@ -0,0 +1,217 @@
|
||||
package uuid
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestNew tests UUID generation
|
||||
func TestNew(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
runs int
|
||||
}{
|
||||
{
|
||||
name: "generate single UUID",
|
||||
runs: 1,
|
||||
},
|
||||
{
|
||||
name: "generate multiple UUIDs",
|
||||
runs: 100,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for i := 0; i < tt.runs; i++ {
|
||||
uuid := New()
|
||||
|
||||
// Verify UUID is 16 bytes
|
||||
assert.Equal(t, 16, len(uuid))
|
||||
|
||||
// Verify version is 4
|
||||
version := (uuid[6] >> 4) & 0x0f
|
||||
assert.Equal(t, uint8(4), version, "UUID version should be 4")
|
||||
|
||||
// Verify variant is RFC 4122
|
||||
variant := (uuid[8] >> 6) & 0x03
|
||||
assert.Equal(t, uint8(2), variant, "UUID variant should be RFC 4122 (10 in binary)")
|
||||
|
||||
// Check uniqueness
|
||||
str := uuid.String()
|
||||
assert.False(t, seen[str], "UUID should be unique")
|
||||
seen[str] = true
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestString tests UUID string formatting
|
||||
func TestString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uuid UUID
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "zero UUID",
|
||||
uuid: UUID{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
expected: "00000000-0000-0000-0000-000000000000",
|
||||
},
|
||||
{
|
||||
name: "all ones UUID",
|
||||
uuid: UUID{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255},
|
||||
expected: "ffffffff-ffff-ffff-ffff-ffffffffffff",
|
||||
},
|
||||
{
|
||||
name: "mixed values UUID",
|
||||
uuid: UUID{0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88},
|
||||
expected: "12345678-9abc-def0-1122-334455667788",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
str := tt.uuid.String()
|
||||
assert.Equal(t, tt.expected, str)
|
||||
|
||||
// Verify format matches UUID regex
|
||||
matched, err := regexp.MatchString(`^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`, str)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, matched, "UUID string should match standard format")
|
||||
|
||||
// Verify dashes are in correct positions
|
||||
assert.Equal(t, "-", string(str[8]))
|
||||
assert.Equal(t, "-", string(str[13]))
|
||||
assert.Equal(t, "-", string(str[18]))
|
||||
assert.Equal(t, "-", string(str[23]))
|
||||
|
||||
// Verify length
|
||||
assert.Equal(t, 36, len(str))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUUIDFormat tests that generated UUIDs match the standard format
|
||||
func TestUUIDFormat(t *testing.T) {
|
||||
const iterations = 1000
|
||||
|
||||
// Compile regex once for performance
|
||||
hexPattern := regexp.MustCompile(`^[0-9a-f]+$`)
|
||||
|
||||
for i := 0; i < iterations; i++ {
|
||||
uuid := New()
|
||||
str := uuid.String()
|
||||
|
||||
// Test standard UUID format: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
|
||||
parts := strings.Split(str, "-")
|
||||
require.Len(t, parts, 5, "UUID should have 5 parts separated by dashes")
|
||||
assert.Equal(t, 8, len(parts[0]), "First part should be 8 characters")
|
||||
assert.Equal(t, 4, len(parts[1]), "Second part should be 4 characters")
|
||||
assert.Equal(t, 4, len(parts[2]), "Third part should be 4 characters")
|
||||
assert.Equal(t, 4, len(parts[3]), "Fourth part should be 4 characters")
|
||||
assert.Equal(t, 12, len(parts[4]), "Fifth part should be 12 characters")
|
||||
|
||||
// Verify all characters are hexadecimal
|
||||
for _, part := range parts {
|
||||
assert.True(t, hexPattern.MatchString(part), "UUID parts should only contain hex characters")
|
||||
}
|
||||
|
||||
// Verify version bits (4th character of third part should start with 4)
|
||||
versionChar := parts[2][0]
|
||||
assert.Equal(t, byte('4'), versionChar, "UUID version should be 4")
|
||||
|
||||
// Verify variant bits (first character of fourth part should be 8, 9, a, or b)
|
||||
variantChar := parts[3][0]
|
||||
assert.Contains(t, []byte{'8', '9', 'a', 'b'}, variantChar, "UUID variant should be RFC 4122")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentGeneration tests that UUID generation is safe for concurrent use
|
||||
func TestConcurrentGeneration(t *testing.T) {
|
||||
const numGoroutines = 100
|
||||
const uuidsPerGoroutine = 100
|
||||
|
||||
results := make(chan UUID, numGoroutines*uuidsPerGoroutine)
|
||||
|
||||
// Generate UUIDs concurrently
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func() {
|
||||
for j := 0; j < uuidsPerGoroutine; j++ {
|
||||
results <- New()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Collect all UUIDs
|
||||
seen := make(map[string]bool)
|
||||
for i := 0; i < numGoroutines*uuidsPerGoroutine; i++ {
|
||||
uuid := <-results
|
||||
str := uuid.String()
|
||||
|
||||
// Verify uniqueness
|
||||
assert.False(t, seen[str], "UUID should be unique even in concurrent generation")
|
||||
seen[str] = true
|
||||
|
||||
// Verify version and variant
|
||||
version := (uuid[6] >> 4) & 0x0f
|
||||
assert.Equal(t, uint8(4), version)
|
||||
|
||||
variant := (uuid[8] >> 6) & 0x03
|
||||
assert.Equal(t, uint8(2), variant)
|
||||
}
|
||||
|
||||
// Verify we got all expected UUIDs
|
||||
assert.Equal(t, numGoroutines*uuidsPerGoroutine, len(seen))
|
||||
}
|
||||
|
||||
// TestUUIDEquality tests UUID equality
|
||||
func TestUUIDEquality(t *testing.T) {
|
||||
uuid1 := New()
|
||||
uuid2 := New()
|
||||
|
||||
// Different UUIDs should not be equal
|
||||
assert.NotEqual(t, uuid1, uuid2)
|
||||
assert.NotEqual(t, uuid1.String(), uuid2.String())
|
||||
|
||||
// Same UUID should be equal
|
||||
uuid3 := uuid1
|
||||
assert.Equal(t, uuid1, uuid3)
|
||||
assert.Equal(t, uuid1.String(), uuid3.String())
|
||||
}
|
||||
|
||||
// TestUUIDArrayAccess tests that UUID can be accessed as a byte array
|
||||
func TestUUIDArrayAccess(t *testing.T) {
|
||||
uuid := New()
|
||||
|
||||
// Verify we can access all bytes
|
||||
for i := 0; i < 16; i++ {
|
||||
_ = uuid[i]
|
||||
}
|
||||
|
||||
// Verify length
|
||||
assert.Equal(t, 16, len(uuid))
|
||||
}
|
||||
|
||||
// BenchmarkNew benchmarks UUID generation
|
||||
func BenchmarkNew(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = New()
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkString benchmarks UUID string conversion
|
||||
func BenchmarkString(b *testing.B) {
|
||||
uuid := New()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = uuid.String()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,247 @@
|
||||
package vcs
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// CredentialStore manages git credentials for different repository patterns
|
||||
type CredentialStore struct {
|
||||
credentials []CredentialEntry
|
||||
}
|
||||
|
||||
// CredentialEntry represents credentials for a specific pattern
|
||||
type CredentialEntry struct {
|
||||
Pattern string `json:"pattern"` // Glob pattern: "github.com/myorg/*"
|
||||
Host string `json:"host"` // Git host: "github.com"
|
||||
Username string `json:"username"` // Usually "oauth2" for tokens
|
||||
Token string `json:"token"` // Access token
|
||||
Fallback bool `json:"fallback"` // Use as fallback if no match
|
||||
}
|
||||
|
||||
// CredentialConfig represents the JSON configuration format
|
||||
type CredentialConfig struct {
|
||||
Credentials []CredentialEntry `json:"credentials"`
|
||||
}
|
||||
|
||||
// NewCredentialStore creates a new credential store
|
||||
func NewCredentialStore() *CredentialStore {
|
||||
return &CredentialStore{
|
||||
credentials: make([]CredentialEntry, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// LoadFromFile loads credentials from a JSON file
|
||||
func (cs *CredentialStore) LoadFromFile(path string) error {
|
||||
if path == "" {
|
||||
log.Debug().Msg("No credential file specified, using system git config")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
log.Warn().Str("path", path).Msg("Credential file not found, using system git config")
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path) // #nosec G304 -- Path is from config, not user input
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read credential file: %w", err)
|
||||
}
|
||||
|
||||
var config CredentialConfig
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
return fmt.Errorf("failed to parse credential file: %w", err)
|
||||
}
|
||||
|
||||
cs.credentials = config.Credentials
|
||||
|
||||
log.Info().
|
||||
Str("file", path).
|
||||
Int("credentials", len(cs.credentials)).
|
||||
Msg("Loaded git credentials from file")
|
||||
|
||||
// Log patterns (not tokens!) for debugging
|
||||
for i, cred := range cs.credentials {
|
||||
log.Debug().
|
||||
Int("index", i).
|
||||
Str("pattern", cred.Pattern).
|
||||
Str("host", cred.Host).
|
||||
Bool("fallback", cred.Fallback).
|
||||
Msg("Registered credential pattern")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCredentialsForModule finds the best matching credentials for a module path
|
||||
// Returns (username, token, found)
|
||||
func (cs *CredentialStore) GetCredentialsForModule(modulePath string) (string, string, bool) {
|
||||
if len(cs.credentials) == 0 {
|
||||
// No credentials configured, rely on system git config
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
// Find best match
|
||||
var bestMatch *CredentialEntry
|
||||
var fallbackMatch *CredentialEntry
|
||||
bestMatchLen := 0
|
||||
|
||||
for i := range cs.credentials {
|
||||
cred := &cs.credentials[i]
|
||||
|
||||
// Check for fallback
|
||||
if cred.Fallback {
|
||||
fallbackMatch = cred
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if pattern matches
|
||||
if cs.matchPattern(cred.Pattern, modulePath) {
|
||||
// Use longest matching pattern (most specific)
|
||||
if len(cred.Pattern) > bestMatchLen {
|
||||
bestMatch = cred
|
||||
bestMatchLen = len(cred.Pattern)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use best match if found
|
||||
if bestMatch != nil {
|
||||
log.Debug().
|
||||
Str("module", modulePath).
|
||||
Str("pattern", bestMatch.Pattern).
|
||||
Str("host", bestMatch.Host).
|
||||
Msg("Matched credential pattern")
|
||||
return bestMatch.Username, bestMatch.Token, true
|
||||
}
|
||||
|
||||
// Use fallback if available
|
||||
if fallbackMatch != nil {
|
||||
log.Debug().
|
||||
Str("module", modulePath).
|
||||
Str("pattern", fallbackMatch.Pattern).
|
||||
Msg("Using fallback credentials")
|
||||
return fallbackMatch.Username, fallbackMatch.Token, true
|
||||
}
|
||||
|
||||
// No match found
|
||||
log.Debug().
|
||||
Str("module", modulePath).
|
||||
Msg("No credential pattern matched, using system git config")
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
// matchPattern checks if a module path matches a credential pattern
|
||||
// Supports glob-style patterns:
|
||||
// - github.com/myorg/* matches github.com/myorg/repo1, github.com/myorg/repo2
|
||||
// - github.com/myorg/repo matches exactly github.com/myorg/repo
|
||||
// - * matches everything
|
||||
func (cs *CredentialStore) matchPattern(pattern, modulePath string) bool {
|
||||
// Exact match
|
||||
if pattern == modulePath {
|
||||
return true
|
||||
}
|
||||
|
||||
// Wildcard match all
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Glob-style matching
|
||||
matched, err := filepath.Match(pattern, modulePath)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("pattern", pattern).Msg("Invalid pattern")
|
||||
return false
|
||||
}
|
||||
|
||||
if matched {
|
||||
return true
|
||||
}
|
||||
|
||||
// Prefix matching with /*
|
||||
if strings.HasSuffix(pattern, "/*") {
|
||||
prefix := strings.TrimSuffix(pattern, "/*")
|
||||
return strings.HasPrefix(modulePath, prefix+"/")
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// CreateNetrcContent creates .netrc file content for a specific host
|
||||
func (cs *CredentialStore) CreateNetrcContent(host, username, token string) string {
|
||||
return fmt.Sprintf("machine %s\nlogin %s\npassword %s\n", host, username, token)
|
||||
}
|
||||
|
||||
// GetCredentialsForHost finds credentials for a specific git host (e.g., "github.com")
|
||||
// This is useful when you need credentials for a host but don't have a full module path
|
||||
func (cs *CredentialStore) GetCredentialsForHost(host string) (string, string, bool) {
|
||||
if len(cs.credentials) == 0 {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
// Look for exact host match first
|
||||
for i := range cs.credentials {
|
||||
cred := &cs.credentials[i]
|
||||
if cred.Host == host && !cred.Fallback {
|
||||
log.Debug().
|
||||
Str("host", host).
|
||||
Str("pattern", cred.Pattern).
|
||||
Msg("Found credentials for host")
|
||||
return cred.Username, cred.Token, true
|
||||
}
|
||||
}
|
||||
|
||||
// Try fallback
|
||||
for i := range cs.credentials {
|
||||
cred := &cs.credentials[i]
|
||||
if cred.Fallback {
|
||||
log.Debug().
|
||||
Str("host", host).
|
||||
Msg("Using fallback credentials for host")
|
||||
return cred.Username, cred.Token, true
|
||||
}
|
||||
}
|
||||
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
// ValidateConfig validates the credential configuration
|
||||
func (cs *CredentialStore) ValidateConfig() error {
|
||||
hostPatterns := make(map[string]bool)
|
||||
|
||||
for i, cred := range cs.credentials {
|
||||
// Check required fields
|
||||
if cred.Pattern == "" {
|
||||
return fmt.Errorf("credential entry %d: pattern is required", i)
|
||||
}
|
||||
if cred.Host == "" && cred.Pattern != "*" {
|
||||
return fmt.Errorf("credential entry %d: host is required (pattern: %s)", i, cred.Pattern)
|
||||
}
|
||||
if cred.Token == "" {
|
||||
return fmt.Errorf("credential entry %d: token is required (pattern: %s)", i, cred.Pattern)
|
||||
}
|
||||
|
||||
// Set default username if not provided
|
||||
if cred.Username == "" {
|
||||
cs.credentials[i].Username = "oauth2"
|
||||
}
|
||||
|
||||
// Check for duplicate patterns
|
||||
key := cred.Pattern + ":" + cred.Host
|
||||
if hostPatterns[key] && !cred.Fallback {
|
||||
log.Warn().
|
||||
Str("pattern", cred.Pattern).
|
||||
Str("host", cred.Host).
|
||||
Msg("Duplicate credential pattern, last one wins")
|
||||
}
|
||||
hostPatterns[key] = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
+280
@@ -0,0 +1,280 @@
|
||||
package vcs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// GitFetcher handles git repository operations
|
||||
type GitFetcher struct {
|
||||
workDir string
|
||||
timeout time.Duration
|
||||
credStore *CredentialStore
|
||||
}
|
||||
|
||||
// NewGitFetcher creates a new git fetcher
|
||||
func NewGitFetcher(workDir string, credStore *CredentialStore) *GitFetcher {
|
||||
if workDir == "" {
|
||||
workDir = os.TempDir()
|
||||
}
|
||||
|
||||
if credStore == nil {
|
||||
credStore = NewCredentialStore()
|
||||
}
|
||||
|
||||
return &GitFetcher{
|
||||
workDir: workDir,
|
||||
timeout: 30 * time.Second,
|
||||
credStore: credStore,
|
||||
}
|
||||
}
|
||||
|
||||
// FetchModule clones a git repository and checks out a specific version
|
||||
// Returns the path to the checked-out source directory
|
||||
func (g *GitFetcher) FetchModule(ctx context.Context, modulePath, version, credentials string) (string, error) {
|
||||
// Create context with timeout
|
||||
ctx, cancel := context.WithTimeout(ctx, g.timeout)
|
||||
defer cancel()
|
||||
|
||||
// Parse module path to extract repository URL
|
||||
repoURL, err := g.modulePathToRepoURL(modulePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Create temporary directory for this clone
|
||||
cloneDir, err := os.MkdirTemp(g.workDir, "gohoarder-git-*")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temp directory: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("module", modulePath).
|
||||
Str("version", version).
|
||||
Str("repo_url", repoURL).
|
||||
Str("clone_dir", cloneDir).
|
||||
Msg("Fetching module from git")
|
||||
|
||||
// Set up credentials
|
||||
credentialHelper, cleanup, err := g.setupCredentials(repoURL, modulePath, credentials)
|
||||
if err != nil {
|
||||
_ = os.RemoveAll(cloneDir) // #nosec G104 -- Cleanup
|
||||
return "", fmt.Errorf("failed to setup credentials: %w", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
// Try shallow clone with specific version first (fastest)
|
||||
if err := g.shallowClone(ctx, repoURL, version, cloneDir, credentialHelper); err != nil {
|
||||
log.Debug().Err(err).Msg("Shallow clone failed, trying full clone")
|
||||
|
||||
// Fallback to full clone
|
||||
if err := g.fullClone(ctx, repoURL, cloneDir, credentialHelper); err != nil {
|
||||
_ = os.RemoveAll(cloneDir) // #nosec G104 -- Cleanup
|
||||
return "", fmt.Errorf("git clone failed: %w", err)
|
||||
}
|
||||
|
||||
// Checkout specific version
|
||||
if err := g.checkout(ctx, cloneDir, version); err != nil {
|
||||
_ = os.RemoveAll(cloneDir) // #nosec G104 -- Cleanup
|
||||
return "", fmt.Errorf("git checkout failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("module", modulePath).
|
||||
Str("version", version).
|
||||
Str("path", cloneDir).
|
||||
Msg("Successfully fetched module from git")
|
||||
|
||||
return cloneDir, nil
|
||||
}
|
||||
|
||||
// modulePathToRepoURL converts a Go module path to a git repository URL
|
||||
// Examples:
|
||||
//
|
||||
// github.com/user/repo → https://github.com/user/repo.git
|
||||
// gitlab.com/group/project → https://gitlab.com/group/project.git
|
||||
func (g *GitFetcher) modulePathToRepoURL(modulePath string) (string, error) {
|
||||
// Remove any path components after the repository
|
||||
// e.g., github.com/user/repo/v2 → github.com/user/repo
|
||||
parts := strings.Split(modulePath, "/")
|
||||
if len(parts) < 3 {
|
||||
return "", fmt.Errorf("invalid module path: %s", modulePath)
|
||||
}
|
||||
|
||||
// For github.com, gitlab.com, bitbucket.org, etc.
|
||||
// Format: host/owner/repo
|
||||
host := parts[0]
|
||||
owner := parts[1]
|
||||
repo := parts[2]
|
||||
|
||||
// Remove version suffix if present (e.g., /v2, /v3)
|
||||
repo = strings.TrimPrefix(repo, "v")
|
||||
|
||||
repoURL := fmt.Sprintf("https://%s/%s/%s.git", host, owner, repo)
|
||||
return repoURL, nil
|
||||
}
|
||||
|
||||
// setupCredentials configures git credentials for authentication
|
||||
// Returns credential helper configuration and cleanup function
|
||||
func (g *GitFetcher) setupCredentials(repoURL, modulePath, credentials string) (map[string]string, func(), error) {
|
||||
env := make(map[string]string)
|
||||
cleanup := func() {}
|
||||
|
||||
// Priority 1: Check credential store for pattern-based credentials
|
||||
if g.credStore != nil {
|
||||
username, token, found := g.credStore.GetCredentialsForModule(modulePath)
|
||||
if found {
|
||||
log.Debug().
|
||||
Str("module", modulePath).
|
||||
Msg("Using credentials from credential store")
|
||||
return g.createTempNetrc(repoURL, username, token)
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 2: Use credentials from HTTP Authorization header (if provided)
|
||||
if credentials != "" {
|
||||
log.Debug().Msg("Using credentials from Authorization header")
|
||||
return g.createTempNetrcFromHeader(repoURL, credentials)
|
||||
}
|
||||
|
||||
// Priority 3: Rely on system git config (.netrc, etc.)
|
||||
log.Debug().Msg("No credentials provided, using system git config")
|
||||
return env, cleanup, nil
|
||||
}
|
||||
|
||||
// createTempNetrc creates a temporary .netrc file with the provided credentials
|
||||
func (g *GitFetcher) createTempNetrc(repoURL, username, token string) (map[string]string, func(), error) {
|
||||
// Create temporary .netrc file
|
||||
tempDir, err := os.MkdirTemp("", "gohoarder-netrc-*")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create temp netrc directory: %w", err)
|
||||
}
|
||||
|
||||
// Extract host from repo URL
|
||||
host := g.extractHost(repoURL)
|
||||
|
||||
// Create .netrc file
|
||||
netrcPath := filepath.Join(tempDir, ".netrc")
|
||||
netrcContent := fmt.Sprintf("machine %s\nlogin %s\npassword %s\n", host, username, token)
|
||||
if err := os.WriteFile(netrcPath, []byte(netrcContent), 0600); err != nil {
|
||||
_ = os.RemoveAll(tempDir) // #nosec G104 -- Cleanup
|
||||
return nil, nil, fmt.Errorf("failed to write .netrc: %w", err)
|
||||
}
|
||||
|
||||
env := map[string]string{
|
||||
"HOME": tempDir,
|
||||
"GIT_TERMINAL_PROMPT": "0",
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
_ = os.RemoveAll(tempDir) // #nosec G104 -- Cleanup
|
||||
}
|
||||
|
||||
log.Debug().Str("host", host).Msg("Created temporary .netrc for git authentication")
|
||||
|
||||
return env, cleanup, nil
|
||||
}
|
||||
|
||||
// createTempNetrcFromHeader creates a temporary .netrc from Authorization header credentials
|
||||
func (g *GitFetcher) createTempNetrcFromHeader(repoURL, credentials string) (map[string]string, func(), error) {
|
||||
// Extract token from credentials
|
||||
token := strings.TrimPrefix(credentials, "Bearer ")
|
||||
token = strings.TrimPrefix(token, "Token ")
|
||||
token = strings.TrimPrefix(token, "Private-Token ")
|
||||
|
||||
if token == "" || token == credentials {
|
||||
// Not in expected format, rely on system config
|
||||
log.Debug().Msg("Credentials not in Bearer/Token format, using system git config")
|
||||
return make(map[string]string), func() {}, nil
|
||||
}
|
||||
|
||||
// Use oauth2 as default username for token-based auth
|
||||
return g.createTempNetrc(repoURL, "oauth2", token)
|
||||
}
|
||||
|
||||
// extractHost extracts the git host from a repository URL
|
||||
func (g *GitFetcher) extractHost(repoURL string) string {
|
||||
if strings.Contains(repoURL, "github.com") {
|
||||
return "github.com"
|
||||
}
|
||||
if strings.Contains(repoURL, "gitlab.com") {
|
||||
return "gitlab.com"
|
||||
}
|
||||
if strings.Contains(repoURL, "bitbucket.org") {
|
||||
return "bitbucket.org"
|
||||
}
|
||||
|
||||
// Generic extraction
|
||||
parts := strings.Split(repoURL, "/")
|
||||
if len(parts) >= 3 {
|
||||
return strings.TrimPrefix(parts[2], "//")
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// shallowClone performs a shallow clone of a specific version
|
||||
func (g *GitFetcher) shallowClone(ctx context.Context, repoURL, version, cloneDir string, credentialHelper map[string]string) error {
|
||||
cmd := exec.CommandContext(ctx, "git", "clone", "--depth", "1", "--branch", version, repoURL, cloneDir)
|
||||
cmd.Env = append(os.Environ(), g.envMapToSlice(credentialHelper)...)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("shallow clone failed: %w (output: %s)", err, string(output))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// fullClone performs a full clone of the repository
|
||||
func (g *GitFetcher) fullClone(ctx context.Context, repoURL, cloneDir string, credentialHelper map[string]string) error {
|
||||
cmd := exec.CommandContext(ctx, "git", "clone", repoURL, cloneDir)
|
||||
cmd.Env = append(os.Environ(), g.envMapToSlice(credentialHelper)...)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("full clone failed: %w (output: %s)", err, string(output))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkout checks out a specific version (tag, branch, or commit)
|
||||
func (g *GitFetcher) checkout(ctx context.Context, repoDir, version string) error {
|
||||
cmd := exec.CommandContext(ctx, "git", "checkout", version)
|
||||
cmd.Dir = repoDir
|
||||
cmd.Env = append(os.Environ(), "GIT_TERMINAL_PROMPT=0")
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("checkout failed: %w (output: %s)", err, string(output))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// envMapToSlice converts environment map to slice
|
||||
func (g *GitFetcher) envMapToSlice(envMap map[string]string) []string {
|
||||
var env []string
|
||||
for k, v := range envMap {
|
||||
env = append(env, fmt.Sprintf("%s=%s", k, v))
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
// Cleanup removes temporary directories
|
||||
func (g *GitFetcher) Cleanup(paths ...string) {
|
||||
for _, path := range paths {
|
||||
if err := os.RemoveAll(path); err != nil {
|
||||
log.Warn().Err(err).Str("path", path).Msg("Failed to cleanup temporary directory")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,252 @@
|
||||
package vcs
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// ModuleBuilder builds Go module artifacts from source
|
||||
type ModuleBuilder struct{}
|
||||
|
||||
// NewModuleBuilder creates a new module builder
|
||||
func NewModuleBuilder() *ModuleBuilder {
|
||||
return &ModuleBuilder{}
|
||||
}
|
||||
|
||||
// ModuleInfo represents Go module version metadata (.info file)
|
||||
type ModuleInfo struct {
|
||||
Version string `json:"Version"`
|
||||
Time time.Time `json:"Time"`
|
||||
}
|
||||
|
||||
// BuildModuleZip creates a Go module zip from source directory
|
||||
// Follows the Go module zip format specification: https://go.dev/ref/mod#zip-files
|
||||
func (b *ModuleBuilder) BuildModuleZip(ctx context.Context, srcPath, modulePath, version string) (io.ReadCloser, error) {
|
||||
log.Debug().
|
||||
Str("src_path", srcPath).
|
||||
Str("module", modulePath).
|
||||
Str("version", version).
|
||||
Msg("Building module zip")
|
||||
|
||||
// Create in-memory zip
|
||||
var buf bytes.Buffer
|
||||
zipWriter := zip.NewWriter(&buf)
|
||||
|
||||
// Collect all files to include in zip
|
||||
files, err := b.collectFiles(srcPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to collect files: %w", err)
|
||||
}
|
||||
|
||||
// Sort files for deterministic zip
|
||||
sort.Strings(files)
|
||||
|
||||
// Add files to zip with proper prefix
|
||||
prefix := fmt.Sprintf("%s@%s/", modulePath, version)
|
||||
for _, relPath := range files {
|
||||
if err := b.addFileToZip(zipWriter, srcPath, relPath, prefix); err != nil {
|
||||
zipWriter.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
return nil, fmt.Errorf("failed to add file %s: %w", relPath, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := zipWriter.Close(); err != nil {
|
||||
return nil, fmt.Errorf("failed to close zip writer: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("module", modulePath).
|
||||
Str("version", version).
|
||||
Int("files", len(files)).
|
||||
Int("size", buf.Len()).
|
||||
Msg("Successfully built module zip")
|
||||
|
||||
return io.NopCloser(bytes.NewReader(buf.Bytes())), nil
|
||||
}
|
||||
|
||||
// collectFiles walks the source directory and collects files to include
|
||||
func (b *ModuleBuilder) collectFiles(srcPath string) ([]string, error) {
|
||||
var files []string
|
||||
|
||||
err := filepath.Walk(srcPath, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Skip directories
|
||||
if info.IsDir() {
|
||||
// Skip .git directory
|
||||
if info.Name() == ".git" {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
// Skip vendor directory (per Go module zip spec)
|
||||
if info.Name() == "vendor" {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get relative path
|
||||
relPath, err := filepath.Rel(srcPath, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Skip hidden files (except .gitignore, etc. if needed)
|
||||
if strings.HasPrefix(filepath.Base(relPath), ".") && relPath != ".gitignore" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Include file
|
||||
files = append(files, relPath)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return files, nil
|
||||
}
|
||||
|
||||
// addFileToZip adds a single file to the zip archive
|
||||
func (b *ModuleBuilder) addFileToZip(zipWriter *zip.Writer, srcPath, relPath, prefix string) error {
|
||||
// Create zip header
|
||||
header := &zip.FileHeader{
|
||||
Name: prefix + filepath.ToSlash(relPath),
|
||||
Method: zip.Deflate,
|
||||
}
|
||||
|
||||
// Get file info for permissions
|
||||
fullPath := filepath.Join(srcPath, relPath)
|
||||
info, err := os.Stat(fullPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set modification time to a fixed value for deterministic zips
|
||||
// Go uses the timestamp from the version info
|
||||
header.Modified = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
header.SetMode(info.Mode())
|
||||
|
||||
// Create file in zip
|
||||
writer, err := zipWriter.CreateHeader(header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Copy file contents
|
||||
file, err := os.Open(fullPath) // #nosec G304 -- Path is from zip archive extraction
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
if _, err := io.Copy(writer, file); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateModInfo creates .info file (JSON metadata)
|
||||
func (b *ModuleBuilder) GenerateModInfo(ctx context.Context, srcPath, version string) ([]byte, error) {
|
||||
// Get commit timestamp from git
|
||||
timestamp, err := b.getGitCommitTime(srcPath)
|
||||
if err != nil {
|
||||
// Fallback to current time if git info not available
|
||||
log.Warn().Err(err).Msg("Failed to get git commit time, using current time")
|
||||
timestamp = time.Now()
|
||||
}
|
||||
|
||||
info := ModuleInfo{
|
||||
Version: version,
|
||||
Time: timestamp,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(info)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal module info: %w", err)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// getGitCommitTime retrieves the commit timestamp from git
|
||||
func (b *ModuleBuilder) getGitCommitTime(repoPath string) (time.Time, error) {
|
||||
cmd := exec.Command("git", "log", "-1", "--format=%cI")
|
||||
cmd.Dir = repoPath
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
|
||||
// Parse ISO 8601 timestamp
|
||||
timestamp, err := time.Parse(time.RFC3339, strings.TrimSpace(string(output)))
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
|
||||
return timestamp, nil
|
||||
}
|
||||
|
||||
// ExtractGoMod extracts go.mod content
|
||||
func (b *ModuleBuilder) ExtractGoMod(ctx context.Context, srcPath string) ([]byte, error) {
|
||||
goModPath := filepath.Join(srcPath, "go.mod")
|
||||
|
||||
data, err := os.ReadFile(goModPath) // #nosec G304 -- Path is from controlled temp directory
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read go.mod: %w", err)
|
||||
}
|
||||
|
||||
// Validate go.mod (basic check)
|
||||
if !strings.Contains(string(data), "module ") {
|
||||
return nil, fmt.Errorf("invalid go.mod: missing module directive")
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// ValidateModule performs basic validation on the module
|
||||
func (b *ModuleBuilder) ValidateModule(ctx context.Context, srcPath, expectedModulePath string) error {
|
||||
// Read go.mod
|
||||
goModData, err := b.ExtractGoMod(ctx, srcPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Extract module path from go.mod
|
||||
lines := strings.Split(string(goModData), "\n")
|
||||
var declaredModulePath string
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(line, "module ") {
|
||||
declaredModulePath = strings.TrimSpace(strings.TrimPrefix(line, "module "))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if declaredModulePath == "" {
|
||||
return fmt.Errorf("go.mod missing module declaration")
|
||||
}
|
||||
|
||||
// Check if module path matches (allow version suffixes)
|
||||
if !strings.HasPrefix(expectedModulePath, declaredModulePath) {
|
||||
return fmt.Errorf("module path mismatch: expected %s, got %s", expectedModulePath, declaredModulePath)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,388 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// EventType represents the type of event being broadcast
|
||||
type EventType string
|
||||
|
||||
const (
|
||||
EventPackageCached EventType = "package_cached"
|
||||
EventPackageDeleted EventType = "package_deleted"
|
||||
EventPackageDownloaded EventType = "package_downloaded"
|
||||
EventScanComplete EventType = "scan_complete"
|
||||
EventStatsUpdate EventType = "stats_update"
|
||||
EventSystemAlert EventType = "system_alert"
|
||||
)
|
||||
|
||||
// Event represents a WebSocket event message
|
||||
type Event struct {
|
||||
Type EventType `json:"type"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
}
|
||||
|
||||
// Client represents a WebSocket client connection
|
||||
type Client struct {
|
||||
conn *websocket.Conn
|
||||
send chan []byte
|
||||
server *Server
|
||||
subscriptions map[EventType]bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Server manages WebSocket connections and event broadcasting
|
||||
type Server struct {
|
||||
clients map[*Client]bool
|
||||
broadcast chan Event
|
||||
register chan *Client
|
||||
unregister chan *Client
|
||||
mu sync.RWMutex
|
||||
upgrader websocket.Upgrader
|
||||
}
|
||||
|
||||
// Config holds WebSocket server configuration
|
||||
type Config struct {
|
||||
ReadBufferSize int
|
||||
WriteBufferSize int
|
||||
CheckOrigin func(r *http.Request) bool
|
||||
}
|
||||
|
||||
// NewServer creates a new WebSocket server
|
||||
func NewServer(cfg Config) *Server {
|
||||
if cfg.CheckOrigin == nil {
|
||||
cfg.CheckOrigin = func(r *http.Request) bool {
|
||||
return true // Allow all origins by default
|
||||
}
|
||||
}
|
||||
|
||||
server := &Server{
|
||||
clients: make(map[*Client]bool),
|
||||
broadcast: make(chan Event, 256),
|
||||
register: make(chan *Client),
|
||||
unregister: make(chan *Client),
|
||||
upgrader: websocket.Upgrader{
|
||||
ReadBufferSize: cfg.ReadBufferSize,
|
||||
WriteBufferSize: cfg.WriteBufferSize,
|
||||
CheckOrigin: cfg.CheckOrigin,
|
||||
},
|
||||
}
|
||||
|
||||
return server
|
||||
}
|
||||
|
||||
// Start starts the WebSocket server event loop
|
||||
func (s *Server) Start(ctx context.Context) {
|
||||
go s.run(ctx)
|
||||
log.Info().Msg("WebSocket server started")
|
||||
}
|
||||
|
||||
// run handles client registration/unregistration and broadcasting
|
||||
func (s *Server) run(ctx context.Context) {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info().Msg("WebSocket server shutting down")
|
||||
s.closeAllClients()
|
||||
return
|
||||
|
||||
case client := <-s.register:
|
||||
s.mu.Lock()
|
||||
s.clients[client] = true
|
||||
s.mu.Unlock()
|
||||
log.Debug().
|
||||
Int("total_clients", len(s.clients)).
|
||||
Msg("Client registered")
|
||||
|
||||
case client := <-s.unregister:
|
||||
s.mu.Lock()
|
||||
if _, ok := s.clients[client]; ok {
|
||||
delete(s.clients, client)
|
||||
close(client.send)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
log.Debug().
|
||||
Int("total_clients", len(s.clients)).
|
||||
Msg("Client unregistered")
|
||||
|
||||
case event := <-s.broadcast:
|
||||
s.broadcastEvent(event)
|
||||
|
||||
case <-ticker.C:
|
||||
// Ping all clients to keep connections alive
|
||||
s.pingClients()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// broadcastEvent sends an event to all subscribed clients
|
||||
func (s *Server) broadcastEvent(event Event) {
|
||||
message, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to marshal event")
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
for client := range s.clients {
|
||||
// Check if client is subscribed to this event type
|
||||
client.mu.RLock()
|
||||
subscribed := len(client.subscriptions) == 0 || client.subscriptions[event.Type]
|
||||
client.mu.RUnlock()
|
||||
|
||||
if subscribed {
|
||||
select {
|
||||
case client.send <- message:
|
||||
default:
|
||||
// Client send buffer full - close connection
|
||||
go func(c *Client) {
|
||||
s.unregister <- c
|
||||
}(client)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("event_type", string(event.Type)).
|
||||
Int("clients_notified", len(s.clients)).
|
||||
Msg("Event broadcast")
|
||||
}
|
||||
|
||||
// pingClients sends ping messages to all connected clients
|
||||
func (s *Server) pingClients() {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
for client := range s.clients {
|
||||
if err := client.conn.WriteControl(
|
||||
websocket.PingMessage,
|
||||
[]byte{},
|
||||
time.Now().Add(10*time.Second),
|
||||
); err != nil {
|
||||
log.Debug().Err(err).Msg("Failed to ping client")
|
||||
go func(c *Client) {
|
||||
s.unregister <- c
|
||||
}(client)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// closeAllClients closes all client connections
|
||||
func (s *Server) closeAllClients() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for client := range s.clients {
|
||||
client.conn.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
close(client.send)
|
||||
}
|
||||
s.clients = make(map[*Client]bool)
|
||||
}
|
||||
|
||||
// Broadcast sends an event to all connected clients
|
||||
func (s *Server) Broadcast(eventType EventType, data map[string]interface{}) {
|
||||
event := Event{
|
||||
Type: eventType,
|
||||
Timestamp: time.Now(),
|
||||
Data: data,
|
||||
}
|
||||
|
||||
select {
|
||||
case s.broadcast <- event:
|
||||
default:
|
||||
log.Warn().Msg("Broadcast channel full - dropping event")
|
||||
}
|
||||
}
|
||||
|
||||
// HandleWebSocket upgrades HTTP connection to WebSocket
|
||||
func (s *Server) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := s.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to upgrade connection")
|
||||
return
|
||||
}
|
||||
|
||||
client := &Client{
|
||||
conn: conn,
|
||||
send: make(chan []byte, 256),
|
||||
server: s,
|
||||
subscriptions: make(map[EventType]bool),
|
||||
}
|
||||
|
||||
s.register <- client
|
||||
|
||||
// Start goroutines for reading and writing
|
||||
go client.readPump()
|
||||
go client.writePump()
|
||||
|
||||
log.Info().
|
||||
Str("remote_addr", r.RemoteAddr).
|
||||
Msg("WebSocket connection established")
|
||||
}
|
||||
|
||||
// readPump handles incoming messages from the client
|
||||
func (c *Client) readPump() {
|
||||
defer func() {
|
||||
c.server.unregister <- c
|
||||
c.conn.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
}()
|
||||
|
||||
_ = c.conn.SetReadDeadline(time.Now().Add(60 * time.Second)) // #nosec G104 -- Websocket deadline
|
||||
c.conn.SetPongHandler(func(string) error { // #nosec G104 -- Websocket handler
|
||||
_ = c.conn.SetReadDeadline(time.Now().Add(60 * time.Second)) // #nosec G104 -- Websocket deadline
|
||||
return nil
|
||||
})
|
||||
|
||||
for {
|
||||
_, message, err := c.conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
log.Error().Err(err).Msg("WebSocket read error")
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// Handle client messages (subscriptions, etc.)
|
||||
c.handleMessage(message)
|
||||
}
|
||||
}
|
||||
|
||||
// writePump handles outgoing messages to the client
|
||||
func (c *Client) writePump() {
|
||||
ticker := time.NewTicker(54 * time.Second)
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
c.conn.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case message, ok := <-c.send:
|
||||
_ = c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) // #nosec G104 -- Websocket deadline, error not critical
|
||||
if !ok {
|
||||
// Channel closed
|
||||
_ = c.conn.WriteMessage(websocket.CloseMessage, []byte{}) // #nosec G104 -- Websocket write
|
||||
return
|
||||
}
|
||||
|
||||
w, err := c.conn.NextWriter(websocket.TextMessage)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = w.Write(message) // #nosec G104 -- Websocket buffer write
|
||||
|
||||
// Write any additional queued messages
|
||||
n := len(c.send)
|
||||
for i := 0; i < n; i++ {
|
||||
_, _ = w.Write([]byte{'\n'}) // #nosec G104 -- Websocket buffer write
|
||||
_, _ = w.Write(<-c.send) // #nosec G104 -- Websocket buffer write
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
case <-ticker.C:
|
||||
_ = c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) // #nosec G104 -- Websocket deadline, error not critical
|
||||
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleMessage processes incoming client messages
|
||||
func (c *Client) handleMessage(message []byte) {
|
||||
var msg struct {
|
||||
Action string `json:"action"`
|
||||
Data interface{} `json:"data"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(message, &msg); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to unmarshal client message")
|
||||
return
|
||||
}
|
||||
|
||||
switch msg.Action {
|
||||
case "subscribe":
|
||||
c.handleSubscribe(msg.Data)
|
||||
case "unsubscribe":
|
||||
c.handleUnsubscribe(msg.Data)
|
||||
case "ping":
|
||||
c.sendPong()
|
||||
default:
|
||||
log.Warn().Str("action", msg.Action).Msg("Unknown client action")
|
||||
}
|
||||
}
|
||||
|
||||
// handleSubscribe subscribes the client to specific event types
|
||||
func (c *Client) handleSubscribe(data interface{}) {
|
||||
eventTypes, ok := data.([]interface{})
|
||||
if !ok {
|
||||
log.Error().Msg("Invalid subscribe data format")
|
||||
return
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
for _, et := range eventTypes {
|
||||
if eventType, ok := et.(string); ok {
|
||||
c.subscriptions[EventType(eventType)] = true
|
||||
log.Debug().
|
||||
Str("event_type", eventType).
|
||||
Msg("Client subscribed to event type")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleUnsubscribe unsubscribes the client from specific event types
|
||||
func (c *Client) handleUnsubscribe(data interface{}) {
|
||||
eventTypes, ok := data.([]interface{})
|
||||
if !ok {
|
||||
log.Error().Msg("Invalid unsubscribe data format")
|
||||
return
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
for _, et := range eventTypes {
|
||||
if eventType, ok := et.(string); ok {
|
||||
delete(c.subscriptions, EventType(eventType))
|
||||
log.Debug().
|
||||
Str("event_type", eventType).
|
||||
Msg("Client unsubscribed from event type")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendPong sends a pong response to the client
|
||||
func (c *Client) sendPong() {
|
||||
response := map[string]string{"type": "pong"}
|
||||
message, _ := json.Marshal(response)
|
||||
select {
|
||||
case c.send <- message:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// GetConnectedClients returns the number of connected clients
|
||||
func (s *Server) GetConnectedClients() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return len(s.clients)
|
||||
}
|
||||
Reference in New Issue
Block a user