mirror of
https://github.com/lukaszraczylo/gohoarder.git
synced 2026-06-05 22:53:53 +00:00
refactor: reorganize struct fields, add new handlers and storage backends
- [x] Reorder struct fields across codebase for consistency - [x] Add analytics event handlers and tests - [x] Add authentication API key management handlers and tests - [x] Add pre-warming control handlers and tests - [x] Implement S3 storage backend with tests - [x] Implement SMB/CIFS storage backend with tests - [x] Add CDN middleware tests - [x] Integrate analytics tracking into cache manager - [x] Add S3 and SMB storage initialization in app setup - [x] Add CDN caching to proxy handlers - [x] Remove distributed locking (Redis lock manager) - [x] Remove proxy common package and utilities - [x] Remove standalone HTTP server package - [x] Remove logger middleware - [x] Simplify error handling utilities - [x] Update config with S3 and SMB options - [x] Update cache manager signature to include analytics
This commit is contained in:
@@ -10,23 +10,23 @@ import (
|
||||
|
||||
// PackageDownload represents a package download event
|
||||
type PackageDownload struct {
|
||||
Timestamp time.Time
|
||||
Registry string
|
||||
Name string
|
||||
Version string
|
||||
Timestamp time.Time
|
||||
BytesSize int64
|
||||
ClientIP string
|
||||
UserAgent string
|
||||
BytesSize int64
|
||||
}
|
||||
|
||||
// PackageStats holds statistics for a package
|
||||
type PackageStats struct {
|
||||
LastDownload time.Time
|
||||
FirstSeen time.Time
|
||||
Registry string
|
||||
Name string
|
||||
TotalDownloads int64
|
||||
UniqueVersions int
|
||||
LastDownload time.Time
|
||||
FirstSeen time.Time
|
||||
BytesServed int64
|
||||
}
|
||||
|
||||
@@ -48,13 +48,13 @@ type PopularPackage struct {
|
||||
|
||||
// Engine tracks and analyzes package downloads
|
||||
type Engine struct {
|
||||
downloads []PackageDownload
|
||||
downloadsMu sync.RWMutex
|
||||
stats map[string]*PackageStats // key: registry:name
|
||||
statsMu sync.RWMutex
|
||||
maxEvents int
|
||||
stats map[string]*PackageStats
|
||||
flushTicker *time.Ticker
|
||||
stopChan chan struct{}
|
||||
downloads []PackageDownload
|
||||
maxEvents int
|
||||
downloadsMu sync.RWMutex
|
||||
statsMu sync.RWMutex
|
||||
}
|
||||
|
||||
// Config holds analytics engine configuration
|
||||
|
||||
+65
-21
@@ -17,7 +17,6 @@ import (
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/cdn"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/config"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/health"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/lock"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
|
||||
metafile "github.com/lukaszraczylo/gohoarder/pkg/metadata/file"
|
||||
metasqlite "github.com/lukaszraczylo/gohoarder/pkg/metadata/sqlite"
|
||||
@@ -30,6 +29,8 @@ import (
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/scanner"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/storage"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/storage/filesystem"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/storage/s3"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/storage/smb"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/vcs"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/websocket"
|
||||
"github.com/rs/zerolog/log"
|
||||
@@ -50,7 +51,6 @@ type App struct {
|
||||
analyticsEngine *analytics.Engine
|
||||
wsServer *websocket.Server
|
||||
prewarmWorker *prewarming.Worker
|
||||
lockManager *lock.Manager
|
||||
cdnMiddleware *cdn.Middleware
|
||||
}
|
||||
|
||||
@@ -82,7 +82,33 @@ func (a *App) initializeComponents() error {
|
||||
switch a.config.Storage.Backend {
|
||||
case "filesystem":
|
||||
a.storage, err = filesystem.New(a.config.Storage.Path, a.config.Cache.MaxSizeBytes)
|
||||
case "s3":
|
||||
a.storage, err = s3.New(s3.Config{
|
||||
Region: a.config.Storage.S3.Region,
|
||||
Bucket: a.config.Storage.S3.Bucket,
|
||||
Prefix: a.config.Storage.S3.Prefix,
|
||||
AccessKeyID: a.config.Storage.S3.AccessKeyID,
|
||||
SecretAccessKey: a.config.Storage.S3.SecretAccessKey,
|
||||
Endpoint: a.config.Storage.S3.Endpoint,
|
||||
ForcePathStyle: a.config.Storage.S3.ForcePathStyle,
|
||||
MaxSizeBytes: a.config.Cache.MaxSizeBytes,
|
||||
})
|
||||
case "smb":
|
||||
a.storage, err = smb.New(smb.Config{
|
||||
Host: a.config.Storage.SMB.Host,
|
||||
Port: 445, // Default SMB port
|
||||
Share: a.config.Storage.SMB.Share,
|
||||
Path: a.config.Storage.Path,
|
||||
Username: a.config.Storage.SMB.Username,
|
||||
Password: a.config.Storage.SMB.Password,
|
||||
Domain: a.config.Storage.SMB.Domain,
|
||||
MaxSizeBytes: a.config.Cache.MaxSizeBytes,
|
||||
PoolSize: 5, // Default connection pool size
|
||||
})
|
||||
default:
|
||||
log.Warn().
|
||||
Str("backend", a.config.Storage.Backend).
|
||||
Msg("Unknown storage backend, defaulting to filesystem")
|
||||
a.storage, err = filesystem.New(a.config.Storage.Path, a.config.Cache.MaxSizeBytes)
|
||||
}
|
||||
if err != nil {
|
||||
@@ -116,9 +142,16 @@ func (a *App) initializeComponents() error {
|
||||
return fmt.Errorf("failed to initialize scanner: %w", err)
|
||||
}
|
||||
|
||||
// Initialize cache manager with scanner
|
||||
// Initialize analytics engine first (needed by cache)
|
||||
log.Info().Msg("Initializing analytics engine")
|
||||
a.analyticsEngine = analytics.NewEngine(analytics.Config{
|
||||
MaxEvents: 10000,
|
||||
FlushInterval: 5 * time.Minute,
|
||||
})
|
||||
|
||||
// Initialize cache manager with scanner and analytics
|
||||
log.Info().Msg("Initializing cache manager")
|
||||
a.cache, err = cache.New(a.storage, a.metadata, a.scanManager, cache.Config{
|
||||
a.cache, err = cache.New(a.storage, a.metadata, a.scanManager, a.analyticsEngine, cache.Config{
|
||||
DefaultTTL: a.config.Cache.DefaultTTL,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
})
|
||||
@@ -153,13 +186,6 @@ func (a *App) initializeComponents() error {
|
||||
a.rescanWorker = scanner.NewRescanWorker(a.scanManager, a.metadata, a.storage, a.config.Security.RescanInterval)
|
||||
}
|
||||
|
||||
// Initialize analytics engine
|
||||
log.Info().Msg("Initializing analytics engine")
|
||||
a.analyticsEngine = analytics.NewEngine(analytics.Config{
|
||||
MaxEvents: 10000,
|
||||
FlushInterval: 5 * time.Minute,
|
||||
})
|
||||
|
||||
// Initialize WebSocket server
|
||||
log.Info().Msg("Initializing WebSocket server")
|
||||
a.wsServer = websocket.NewServer(websocket.Config{
|
||||
@@ -248,9 +274,28 @@ func (a *App) setupServer() error {
|
||||
a.app.Get("/api/stats/timeseries", a.handleTimeSeriesStats)
|
||||
a.app.Get("/api/info", a.handleInfo)
|
||||
|
||||
// Analytics endpoints
|
||||
a.app.Get("/api/analytics/top", a.handleAnalyticsTopPackages)
|
||||
a.app.Get("/api/analytics/trending", a.handleAnalyticsTrendingPackages)
|
||||
a.app.Get("/api/analytics/trends", a.handleAnalyticsTrends)
|
||||
a.app.Get("/api/analytics/total", a.handleAnalyticsTotalStats)
|
||||
a.app.Get("/api/analytics/registry/:registry", a.handleAnalyticsRegistryStats)
|
||||
a.app.Get("/api/analytics/package/:registry/:name", a.handleAnalyticsPackageStats)
|
||||
a.app.Get("/api/analytics/search", a.handleAnalyticsSearch)
|
||||
|
||||
// Admin endpoints (bypass management)
|
||||
a.app.All("/api/admin/bypasses/:id?", a.requireAdmin, a.handleAdminBypasses)
|
||||
|
||||
// Admin endpoints (pre-warming)
|
||||
a.app.Get("/api/admin/prewarming/status", a.requireAdmin, a.handlePrewarmingStatus)
|
||||
a.app.Post("/api/admin/prewarming/trigger", a.requireAdmin, a.handlePrewarmingTrigger)
|
||||
a.app.Post("/api/admin/prewarming/package", a.requireAdmin, a.handlePrewarmingPackage)
|
||||
|
||||
// Admin endpoints (API key management)
|
||||
a.app.Post("/api/admin/keys", a.requireAdmin, a.handleGenerateAPIKey)
|
||||
a.app.Get("/api/admin/keys", a.requireAdmin, a.handleListAPIKeys)
|
||||
a.app.Delete("/api/admin/keys/:key_id", a.requireAdmin, a.handleRevokeAPIKey)
|
||||
|
||||
// Proxy handlers (adapted from net/http)
|
||||
// Load git credentials if configured
|
||||
var credStore *vcs.CredentialStore
|
||||
@@ -270,22 +315,28 @@ func (a *App) setupServer() error {
|
||||
}
|
||||
}
|
||||
|
||||
// Go proxy with CDN caching
|
||||
goProxyHandler := goproxy.New(a.cache, a.networkClient, goproxy.Config{
|
||||
Upstream: "https://proxy.golang.org",
|
||||
SumDBURL: "https://sum.golang.org",
|
||||
CredStore: credStore,
|
||||
})
|
||||
a.app.All("/go/*", adaptor.HTTPHandler(http.StripPrefix("/go", goProxyHandler)))
|
||||
goProxyWithCDN := a.cdnMiddleware.Handler(http.StripPrefix("/go", goProxyHandler))
|
||||
a.app.All("/go/*", adaptor.HTTPHandler(goProxyWithCDN))
|
||||
|
||||
// NPM proxy with CDN caching
|
||||
npmProxyHandler := npm.New(a.cache, a.networkClient, npm.Config{
|
||||
Upstream: "https://registry.npmjs.org",
|
||||
})
|
||||
a.app.All("/npm/*", adaptor.HTTPHandler(http.StripPrefix("/npm", npmProxyHandler)))
|
||||
npmProxyWithCDN := a.cdnMiddleware.Handler(http.StripPrefix("/npm", npmProxyHandler))
|
||||
a.app.All("/npm/*", adaptor.HTTPHandler(npmProxyWithCDN))
|
||||
|
||||
// PyPI proxy with CDN caching
|
||||
pypiProxyHandler := pypi.New(a.cache, a.networkClient, pypi.Config{
|
||||
Upstream: "https://pypi.org/simple",
|
||||
})
|
||||
a.app.All("/pypi/*", adaptor.HTTPHandler(http.StripPrefix("/pypi", pypiProxyHandler)))
|
||||
pypiProxyWithCDN := a.cdnMiddleware.Handler(http.StripPrefix("/pypi", pypiProxyHandler))
|
||||
a.app.All("/pypi/*", adaptor.HTTPHandler(pypiProxyWithCDN))
|
||||
|
||||
// Serve frontend static files
|
||||
frontendDir := "frontend/dist"
|
||||
@@ -397,13 +448,6 @@ func (a *App) Shutdown() error {
|
||||
log.Error().Err(err).Msg("Error closing metadata store")
|
||||
}
|
||||
|
||||
// Close lock manager if initialized
|
||||
if a.lockManager != nil {
|
||||
if err := a.lockManager.Close(); err != nil {
|
||||
log.Error().Err(err).Msg("Error closing lock manager")
|
||||
}
|
||||
}
|
||||
|
||||
log.Info().Msg("Shutdown complete")
|
||||
return nil
|
||||
}
|
||||
|
||||
+4
-5
@@ -500,11 +500,10 @@ func (a *App) handleInfo(c *fiber.Ctx) error {
|
||||
"max_cache_size": a.config.Cache.MaxSizeBytes,
|
||||
},
|
||||
"features": map[string]bool{
|
||||
"distributed_locking": a.lockManager != nil,
|
||||
"security_scanning": a.config.Security.Enabled,
|
||||
"pre_warming": a.prewarmWorker != nil,
|
||||
"websockets": true,
|
||||
"analytics": true,
|
||||
"security_scanning": a.config.Security.Enabled,
|
||||
"pre_warming": a.prewarmWorker != nil,
|
||||
"websockets": true,
|
||||
"analytics": true,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -110,13 +110,13 @@ func (a *App) handleListBypasses(c *fiber.Ctx) error {
|
||||
|
||||
// CreateBypassRequest represents the request body for creating a bypass
|
||||
type CreateBypassRequest struct {
|
||||
Type metadata.BypassType `json:"type"` // "cve" or "package"
|
||||
Target string `json:"target"` // CVE ID or package name
|
||||
Reason string `json:"reason"` // Why this bypass is needed
|
||||
CreatedBy string `json:"created_by"` // Admin username
|
||||
ExpiresInHours int `json:"expires_in_hours"` // How many hours until expiration
|
||||
AppliesTo string `json:"applies_to,omitempty"` // Optional: limit CVE bypass to specific package
|
||||
NotifyOnExpiry bool `json:"notify_on_expiry"` // Send notification when expired
|
||||
Type metadata.BypassType `json:"type"`
|
||||
Target string `json:"target"`
|
||||
Reason string `json:"reason"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
AppliesTo string `json:"applies_to,omitempty"`
|
||||
ExpiresInHours int `json:"expires_in_hours"`
|
||||
NotifyOnExpiry bool `json:"notify_on_expiry"`
|
||||
}
|
||||
|
||||
// handleCreateBypass creates a new CVE bypass
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// handleAnalyticsTopPackages returns the most downloaded packages
|
||||
func (a *App) handleAnalyticsTopPackages(c *fiber.Ctx) error {
|
||||
c.Set("Content-Type", "application/json")
|
||||
c.Set("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Get limit from query params (default: 10)
|
||||
limit := 10
|
||||
if limitStr := c.Query("limit"); limitStr != "" {
|
||||
if parsedLimit, err := strconv.Atoi(limitStr); err == nil && parsedLimit > 0 {
|
||||
limit = parsedLimit
|
||||
}
|
||||
}
|
||||
|
||||
packages := a.analyticsEngine.GetTopPackages(limit)
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(fiber.Map{
|
||||
"packages": packages,
|
||||
"total": len(packages),
|
||||
})
|
||||
}
|
||||
|
||||
// handleAnalyticsTrendingPackages returns trending packages
|
||||
func (a *App) handleAnalyticsTrendingPackages(c *fiber.Ctx) error {
|
||||
c.Set("Content-Type", "application/json")
|
||||
c.Set("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Get limit from query params (default: 10)
|
||||
limit := 10
|
||||
if limitStr := c.Query("limit"); limitStr != "" {
|
||||
if parsedLimit, err := strconv.Atoi(limitStr); err == nil && parsedLimit > 0 {
|
||||
limit = parsedLimit
|
||||
}
|
||||
}
|
||||
|
||||
packages := a.analyticsEngine.GetTrendingPackages(limit)
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(fiber.Map{
|
||||
"packages": packages,
|
||||
"total": len(packages),
|
||||
})
|
||||
}
|
||||
|
||||
// handleAnalyticsTrends returns download trends over time
|
||||
func (a *App) handleAnalyticsTrends(c *fiber.Ctx) error {
|
||||
c.Set("Content-Type", "application/json")
|
||||
c.Set("Access-Control-Allow-Origin", "*")
|
||||
|
||||
trends := a.analyticsEngine.GetTrends()
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(fiber.Map{
|
||||
"trends": trends,
|
||||
})
|
||||
}
|
||||
|
||||
// handleAnalyticsTotalStats returns overall statistics
|
||||
func (a *App) handleAnalyticsTotalStats(c *fiber.Ctx) error {
|
||||
c.Set("Content-Type", "application/json")
|
||||
c.Set("Access-Control-Allow-Origin", "*")
|
||||
|
||||
stats := a.analyticsEngine.GetTotalStats()
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(stats)
|
||||
}
|
||||
|
||||
// handleAnalyticsRegistryStats returns per-registry statistics
|
||||
func (a *App) handleAnalyticsRegistryStats(c *fiber.Ctx) error {
|
||||
c.Set("Content-Type", "application/json")
|
||||
c.Set("Access-Control-Allow-Origin", "*")
|
||||
|
||||
registry := c.Params("registry")
|
||||
if registry == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
"error": "registry parameter is required",
|
||||
})
|
||||
}
|
||||
|
||||
stats := a.analyticsEngine.GetRegistryStats(registry)
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(stats)
|
||||
}
|
||||
|
||||
// handleAnalyticsPackageStats returns statistics for a specific package
|
||||
func (a *App) handleAnalyticsPackageStats(c *fiber.Ctx) error {
|
||||
c.Set("Content-Type", "application/json")
|
||||
c.Set("Access-Control-Allow-Origin", "*")
|
||||
|
||||
registry := c.Params("registry")
|
||||
name := c.Params("name")
|
||||
|
||||
if registry == "" || name == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
"error": "registry and name parameters are required",
|
||||
})
|
||||
}
|
||||
|
||||
stats, exists := a.analyticsEngine.GetPackageStats(registry, name)
|
||||
if !exists {
|
||||
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{
|
||||
"error": "package not found in analytics",
|
||||
})
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(fiber.Map{
|
||||
"package": stats,
|
||||
})
|
||||
}
|
||||
|
||||
// handleAnalyticsSearch searches for packages matching a query
|
||||
func (a *App) handleAnalyticsSearch(c *fiber.Ctx) error {
|
||||
c.Set("Content-Type", "application/json")
|
||||
c.Set("Access-Control-Allow-Origin", "*")
|
||||
|
||||
query := c.Query("q")
|
||||
if query == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
"error": "query parameter 'q' is required",
|
||||
})
|
||||
}
|
||||
|
||||
// Get limit from query params (default: 20)
|
||||
limit := 20
|
||||
if limitStr := c.Query("limit"); limitStr != "" {
|
||||
if parsedLimit, err := strconv.Atoi(limitStr); err == nil && parsedLimit > 0 {
|
||||
limit = parsedLimit
|
||||
}
|
||||
}
|
||||
|
||||
results := a.analyticsEngine.SearchPackages(query, limit)
|
||||
|
||||
log.Debug().
|
||||
Str("query", query).
|
||||
Int("results", len(results)).
|
||||
Msg("Analytics search completed")
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(fiber.Map{
|
||||
"results": results,
|
||||
"total": len(results),
|
||||
"query": query,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,298 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/analytics"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type AnalyticsHandlersTestSuite struct {
|
||||
suite.Suite
|
||||
app *fiber.App
|
||||
appInst *App
|
||||
engine *analytics.Engine
|
||||
}
|
||||
|
||||
func (s *AnalyticsHandlersTestSuite) SetupTest() {
|
||||
// Create analytics engine
|
||||
s.engine = analytics.NewEngine(analytics.Config{
|
||||
MaxEvents: 10000,
|
||||
FlushInterval: 5 * time.Minute,
|
||||
})
|
||||
|
||||
// Seed some test data
|
||||
s.engine.TrackDownload(analytics.PackageDownload{
|
||||
Registry: "npm",
|
||||
Name: "lodash",
|
||||
Version: "4.17.21",
|
||||
Timestamp: time.Now(),
|
||||
BytesSize: 1024,
|
||||
})
|
||||
s.engine.TrackDownload(analytics.PackageDownload{
|
||||
Registry: "npm",
|
||||
Name: "react",
|
||||
Version: "18.0.0",
|
||||
Timestamp: time.Now(),
|
||||
BytesSize: 2048,
|
||||
})
|
||||
s.engine.TrackDownload(analytics.PackageDownload{
|
||||
Registry: "pypi",
|
||||
Name: "requests",
|
||||
Version: "2.28.0",
|
||||
Timestamp: time.Now(),
|
||||
BytesSize: 512,
|
||||
})
|
||||
|
||||
// Create app instance
|
||||
s.appInst = &App{
|
||||
analyticsEngine: s.engine,
|
||||
}
|
||||
|
||||
// Create Fiber app
|
||||
s.app = fiber.New()
|
||||
|
||||
// Register routes
|
||||
s.app.Get("/api/analytics/top", s.appInst.handleAnalyticsTopPackages)
|
||||
s.app.Get("/api/analytics/trending", s.appInst.handleAnalyticsTrendingPackages)
|
||||
s.app.Get("/api/analytics/trends", s.appInst.handleAnalyticsTrends)
|
||||
s.app.Get("/api/analytics/total", s.appInst.handleAnalyticsTotalStats)
|
||||
s.app.Get("/api/analytics/registry/:registry", s.appInst.handleAnalyticsRegistryStats)
|
||||
s.app.Get("/api/analytics/package/:registry/:name", s.appInst.handleAnalyticsPackageStats)
|
||||
s.app.Get("/api/analytics/search", s.appInst.handleAnalyticsSearch)
|
||||
}
|
||||
|
||||
func (s *AnalyticsHandlersTestSuite) TearDownTest() {
|
||||
if s.engine != nil {
|
||||
s.engine.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnalyticsHandlersTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(AnalyticsHandlersTestSuite))
|
||||
}
|
||||
|
||||
func (s *AnalyticsHandlersTestSuite) TestHandleAnalyticsTopPackages() {
|
||||
tests := []struct {
|
||||
name string
|
||||
queryParams string
|
||||
expectedStatus int
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "get top packages default",
|
||||
queryParams: "",
|
||||
expectedStatus: 200,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "get top packages with limit",
|
||||
queryParams: "?limit=5",
|
||||
expectedStatus: 200,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "get top packages with registry filter",
|
||||
queryParams: "?registry=npm",
|
||||
expectedStatus: 200,
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
req := httptest.NewRequest("GET", "/api/analytics/top"+tt.queryParams, nil)
|
||||
resp, err := s.app.Test(req)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(tt.expectedStatus, resp.StatusCode)
|
||||
|
||||
if !tt.expectError {
|
||||
var result struct {
|
||||
Packages []analytics.PackageStats `json:"packages"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
err = json.NewDecoder(resp.Body).Decode(&result)
|
||||
s.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AnalyticsHandlersTestSuite) TestHandleAnalyticsTrendingPackages() {
|
||||
req := httptest.NewRequest("GET", "/api/analytics/trending", nil)
|
||||
resp, err := s.app.Test(req)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(200, resp.StatusCode)
|
||||
|
||||
var result struct {
|
||||
Packages []analytics.PackageStats `json:"packages"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
err = json.NewDecoder(resp.Body).Decode(&result)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
func (s *AnalyticsHandlersTestSuite) TestHandleAnalyticsTrends() {
|
||||
tests := []struct {
|
||||
name string
|
||||
queryParams string
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "get trends default timeframe",
|
||||
queryParams: "",
|
||||
expectedStatus: 200,
|
||||
},
|
||||
{
|
||||
name: "get trends with registry filter",
|
||||
queryParams: "?registry=npm",
|
||||
expectedStatus: 200,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
req := httptest.NewRequest("GET", "/api/analytics/trends"+tt.queryParams, nil)
|
||||
resp, err := s.app.Test(req)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(tt.expectedStatus, resp.StatusCode)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AnalyticsHandlersTestSuite) TestHandleAnalyticsTotalStats() {
|
||||
req := httptest.NewRequest("GET", "/api/analytics/total", nil)
|
||||
resp, err := s.app.Test(req)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(200, resp.StatusCode)
|
||||
|
||||
var result struct {
|
||||
TotalDownloads int64 `json:"total_downloads"`
|
||||
TotalBytes int64 `json:"total_bytes"`
|
||||
UniquePackages int `json:"unique_packages"`
|
||||
}
|
||||
err = json.NewDecoder(resp.Body).Decode(&result)
|
||||
s.NoError(err)
|
||||
s.Greater(result.TotalDownloads, int64(0))
|
||||
}
|
||||
|
||||
func (s *AnalyticsHandlersTestSuite) TestHandleAnalyticsRegistryStats() {
|
||||
tests := []struct {
|
||||
name string
|
||||
registry string
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "npm registry stats",
|
||||
registry: "npm",
|
||||
expectedStatus: 200,
|
||||
},
|
||||
{
|
||||
name: "pypi registry stats",
|
||||
registry: "pypi",
|
||||
expectedStatus: 200,
|
||||
},
|
||||
{
|
||||
name: "go registry stats",
|
||||
registry: "go",
|
||||
expectedStatus: 200,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
req := httptest.NewRequest("GET", "/api/analytics/registry/"+tt.registry, nil)
|
||||
resp, err := s.app.Test(req)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(tt.expectedStatus, resp.StatusCode)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AnalyticsHandlersTestSuite) TestHandleAnalyticsPackageStats() {
|
||||
tests := []struct {
|
||||
name string
|
||||
registry string
|
||||
packageName string
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "lodash package stats",
|
||||
registry: "npm",
|
||||
packageName: "lodash",
|
||||
expectedStatus: 200,
|
||||
},
|
||||
{
|
||||
name: "react package stats",
|
||||
registry: "npm",
|
||||
packageName: "react",
|
||||
expectedStatus: 200,
|
||||
},
|
||||
{
|
||||
name: "requests package stats",
|
||||
registry: "pypi",
|
||||
packageName: "requests",
|
||||
expectedStatus: 200,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
req := httptest.NewRequest("GET", "/api/analytics/package/"+tt.registry+"/"+tt.packageName, nil)
|
||||
resp, err := s.app.Test(req)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(tt.expectedStatus, resp.StatusCode)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AnalyticsHandlersTestSuite) TestHandleAnalyticsSearch() {
|
||||
tests := []struct {
|
||||
name string
|
||||
queryParams string
|
||||
expectedStatus int
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "search for lodash",
|
||||
queryParams: "?q=lodash",
|
||||
expectedStatus: 200,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "search for react",
|
||||
queryParams: "?q=react",
|
||||
expectedStatus: 200,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "search with no query",
|
||||
queryParams: "",
|
||||
expectedStatus: 400, // Query parameter is required
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
req := httptest.NewRequest("GET", "/api/analytics/search"+tt.queryParams, nil)
|
||||
resp, err := s.app.Test(req)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(tt.expectedStatus, resp.StatusCode)
|
||||
|
||||
if !tt.expectError {
|
||||
var result struct {
|
||||
Results []analytics.PackageStats `json:"results"`
|
||||
Total int `json:"total"`
|
||||
Query string `json:"query"`
|
||||
}
|
||||
err = json.NewDecoder(resp.Body).Decode(&result)
|
||||
s.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/auth"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// GenerateAPIKeyRequest represents a request to generate a new API key
|
||||
type GenerateAPIKeyRequest struct {
|
||||
ExpiresInMin *int `json:"expires_in_min"`
|
||||
Name string `json:"name"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
// handleGenerateAPIKey generates a new API key
|
||||
func (a *App) handleGenerateAPIKey(c *fiber.Ctx) error {
|
||||
c.Set("Content-Type", "application/json")
|
||||
|
||||
var req GenerateAPIKeyRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
"error": "invalid JSON in request body",
|
||||
})
|
||||
}
|
||||
|
||||
// Validate request
|
||||
if req.Name == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
"error": "name is required",
|
||||
})
|
||||
}
|
||||
|
||||
// Parse role (default to readonly if not specified)
|
||||
var role auth.Role
|
||||
switch req.Role {
|
||||
case "admin":
|
||||
role = auth.RoleAdmin
|
||||
case "readwrite":
|
||||
role = auth.RoleReadWrite
|
||||
case "readonly", "":
|
||||
role = auth.RoleReadOnly
|
||||
default:
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
"error": "invalid role, must be 'admin', 'readwrite', or 'readonly'",
|
||||
})
|
||||
}
|
||||
|
||||
// Calculate expiration
|
||||
var expiresIn *time.Duration
|
||||
if req.ExpiresInMin != nil {
|
||||
duration := time.Duration(*req.ExpiresInMin) * time.Minute
|
||||
expiresIn = &duration
|
||||
}
|
||||
|
||||
// Generate key
|
||||
apiKey, rawKey, err := a.authManager.GenerateAPIKey(req.Name, role, expiresIn)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("name", req.Name).Msg("Failed to generate API key")
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
"error": "failed to generate API key",
|
||||
})
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("key_id", apiKey.ID).
|
||||
Str("name", apiKey.Name).
|
||||
Str("role", string(apiKey.Role)).
|
||||
Msg("API key generated")
|
||||
|
||||
// Return the key info and raw key (only time it's shown!)
|
||||
return c.Status(fiber.StatusCreated).JSON(fiber.Map{
|
||||
"key": rawKey, // IMPORTANT: This is the only time the raw key is shown
|
||||
"key_id": apiKey.ID,
|
||||
"name": apiKey.Name,
|
||||
"role": apiKey.Role,
|
||||
"expires": apiKey.ExpiresAt,
|
||||
"message": "Save this key now! It will not be shown again.",
|
||||
})
|
||||
}
|
||||
|
||||
// handleListAPIKeys lists all API keys
|
||||
func (a *App) handleListAPIKeys(c *fiber.Ctx) error {
|
||||
c.Set("Content-Type", "application/json")
|
||||
|
||||
keys := a.authManager.ListAPIKeys()
|
||||
|
||||
// Convert to response format (excluding hashed keys)
|
||||
response := make([]fiber.Map, len(keys))
|
||||
for i, key := range keys {
|
||||
response[i] = fiber.Map{
|
||||
"id": key.ID,
|
||||
"name": key.Name,
|
||||
"role": key.Role,
|
||||
"created_at": key.CreatedAt,
|
||||
"expires_at": key.ExpiresAt,
|
||||
"last_used_at": key.LastUsedAt,
|
||||
"permissions": key.Permissions,
|
||||
}
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(fiber.Map{
|
||||
"keys": response,
|
||||
"total": len(response),
|
||||
})
|
||||
}
|
||||
|
||||
// handleRevokeAPIKey revokes an API key
|
||||
func (a *App) handleRevokeAPIKey(c *fiber.Ctx) error {
|
||||
c.Set("Content-Type", "application/json")
|
||||
|
||||
keyID := c.Params("key_id")
|
||||
if keyID == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
"error": "key_id parameter is required",
|
||||
})
|
||||
}
|
||||
|
||||
err := a.authManager.RevokeAPIKey(keyID)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("key_id", keyID).Msg("Failed to revoke API key")
|
||||
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{
|
||||
"error": "API key not found",
|
||||
})
|
||||
}
|
||||
|
||||
log.Info().Str("key_id", keyID).Msg("API key revoked")
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(fiber.Map{
|
||||
"message": "API key revoked successfully",
|
||||
"key_id": keyID,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,285 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/auth"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type AuthHandlersTestSuite struct {
|
||||
suite.Suite
|
||||
app *fiber.App
|
||||
appInst *App
|
||||
authManager *auth.Manager
|
||||
}
|
||||
|
||||
func (s *AuthHandlersTestSuite) SetupTest() {
|
||||
// Create auth manager
|
||||
s.authManager = auth.New()
|
||||
|
||||
// Create app instance
|
||||
s.appInst = &App{
|
||||
authManager: s.authManager,
|
||||
}
|
||||
|
||||
// Create Fiber app
|
||||
s.app = fiber.New()
|
||||
|
||||
// Register routes
|
||||
s.app.Post("/api/admin/keys", s.appInst.handleGenerateAPIKey)
|
||||
s.app.Get("/api/admin/keys", s.appInst.handleListAPIKeys)
|
||||
s.app.Delete("/api/admin/keys/:key_id", s.appInst.handleRevokeAPIKey)
|
||||
}
|
||||
|
||||
func TestAuthHandlersTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(AuthHandlersTestSuite))
|
||||
}
|
||||
|
||||
func (s *AuthHandlersTestSuite) TestHandleGenerateAPIKey() {
|
||||
tests := []struct {
|
||||
requestBody map[string]string
|
||||
name string
|
||||
expectedStatus int
|
||||
expectedRole string
|
||||
expectKey bool
|
||||
}{
|
||||
{
|
||||
name: "generate read-only key",
|
||||
requestBody: map[string]string{
|
||||
"role": "readonly",
|
||||
"name": "test-readonly-key",
|
||||
},
|
||||
expectedStatus: 201,
|
||||
expectedRole: "readonly",
|
||||
expectKey: true,
|
||||
},
|
||||
{
|
||||
name: "generate read-write key",
|
||||
requestBody: map[string]string{
|
||||
"role": "readwrite",
|
||||
"name": "test-readwrite-key",
|
||||
},
|
||||
expectedStatus: 201,
|
||||
expectedRole: "readwrite",
|
||||
expectKey: true,
|
||||
},
|
||||
{
|
||||
name: "generate admin key",
|
||||
requestBody: map[string]string{
|
||||
"role": "admin",
|
||||
"name": "test-admin-key",
|
||||
},
|
||||
expectedStatus: 201,
|
||||
expectedRole: "admin",
|
||||
expectKey: true,
|
||||
},
|
||||
{
|
||||
name: "invalid role",
|
||||
requestBody: map[string]string{
|
||||
"role": "invalid-role",
|
||||
"name": "test-key",
|
||||
},
|
||||
expectedStatus: 400,
|
||||
expectKey: false,
|
||||
},
|
||||
{
|
||||
name: "missing role defaults to readonly",
|
||||
requestBody: map[string]string{
|
||||
"name": "test-key-default-role",
|
||||
},
|
||||
expectedStatus: 201,
|
||||
expectedRole: "readonly", // Role defaults to readonly when not specified
|
||||
expectKey: true,
|
||||
},
|
||||
{
|
||||
name: "missing name",
|
||||
requestBody: map[string]string{
|
||||
"role": "read-only",
|
||||
},
|
||||
expectedStatus: 400,
|
||||
expectKey: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
bodyBytes, err := json.Marshal(tt.requestBody)
|
||||
s.Require().NoError(err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/admin/keys", bytes.NewReader(bodyBytes))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.app.Test(req)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(tt.expectedStatus, resp.StatusCode)
|
||||
|
||||
if tt.expectKey {
|
||||
var result struct {
|
||||
Key string `json:"key"`
|
||||
KeyID string `json:"key_id"`
|
||||
Role string `json:"role"`
|
||||
Name string `json:"name"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
err = json.NewDecoder(resp.Body).Decode(&result)
|
||||
s.NoError(err)
|
||||
s.NotEmpty(result.Key)
|
||||
s.NotEmpty(result.KeyID)
|
||||
s.Equal(tt.expectedRole, result.Role)
|
||||
s.Equal(tt.requestBody["name"], result.Name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthHandlersTestSuite) TestHandleListAPIKeys() {
|
||||
// Generate some test keys first
|
||||
s.authManager.GenerateAPIKey("test-key-1", auth.RoleReadOnly, nil)
|
||||
s.authManager.GenerateAPIKey("test-key-2", auth.RoleReadWrite, nil)
|
||||
s.authManager.GenerateAPIKey("test-key-3", auth.RoleAdmin, nil)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/admin/keys", nil)
|
||||
resp, err := s.app.Test(req)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(200, resp.StatusCode)
|
||||
|
||||
var result struct {
|
||||
Keys []map[string]interface{} `json:"keys"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
err = json.NewDecoder(resp.Body).Decode(&result)
|
||||
s.NoError(err)
|
||||
s.GreaterOrEqual(result.Total, 3)
|
||||
|
||||
// Verify keys don't include the actual key value
|
||||
for _, key := range result.Keys {
|
||||
s.NotEmpty(key["id"])
|
||||
s.NotEmpty(key["role"])
|
||||
s.NotEmpty(key["name"])
|
||||
s.NotEmpty(key["created_at"])
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthHandlersTestSuite) TestHandleRevokeAPIKey() {
|
||||
// Generate a test key
|
||||
keyInfo, _, _ := s.authManager.GenerateAPIKey("test-revoke-key", auth.RoleReadOnly, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
keyID string
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "revoke existing key",
|
||||
keyID: keyInfo.ID,
|
||||
expectedStatus: 200,
|
||||
},
|
||||
{
|
||||
name: "revoke non-existent key",
|
||||
keyID: "non-existent-key-id",
|
||||
expectedStatus: 404,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
req := httptest.NewRequest("DELETE", "/api/admin/keys/"+tt.keyID, nil)
|
||||
resp, err := s.app.Test(req)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(tt.expectedStatus, resp.StatusCode)
|
||||
|
||||
if tt.expectedStatus == 200 {
|
||||
var result map[string]string
|
||||
err = json.NewDecoder(resp.Body).Decode(&result)
|
||||
s.NoError(err)
|
||||
s.Contains(result, "message")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthHandlersTestSuite) TestHandleGenerateAPIKeyInvalidJSON() {
|
||||
req := httptest.NewRequest("POST", "/api/admin/keys", bytes.NewReader([]byte("invalid json")))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.app.Test(req)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(400, resp.StatusCode)
|
||||
}
|
||||
|
||||
func (s *AuthHandlersTestSuite) TestGenerateAndRevokeKeyFlow() {
|
||||
// Generate a key
|
||||
bodyBytes, _ := json.Marshal(map[string]string{
|
||||
"role": "readonly",
|
||||
"name": "integration-test-key",
|
||||
})
|
||||
|
||||
req1 := httptest.NewRequest("POST", "/api/admin/keys", bytes.NewReader(bodyBytes))
|
||||
req1.Header.Set("Content-Type", "application/json")
|
||||
resp1, err := s.app.Test(req1)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(201, resp1.StatusCode)
|
||||
|
||||
var createResult struct {
|
||||
Key string `json:"key"`
|
||||
KeyID string `json:"key_id"`
|
||||
}
|
||||
err = json.NewDecoder(resp1.Body).Decode(&createResult)
|
||||
s.Require().NoError(err)
|
||||
keyID := createResult.KeyID
|
||||
|
||||
// List keys - should include our new key
|
||||
req2 := httptest.NewRequest("GET", "/api/admin/keys", nil)
|
||||
resp2, err := s.app.Test(req2)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(200, resp2.StatusCode)
|
||||
|
||||
var listResult struct {
|
||||
Keys []map[string]interface{} `json:"keys"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
err = json.NewDecoder(resp2.Body).Decode(&listResult)
|
||||
s.Require().NoError(err)
|
||||
|
||||
found := false
|
||||
for _, key := range listResult.Keys {
|
||||
if key["id"].(string) == keyID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
s.True(found, "newly created key should be in the list")
|
||||
|
||||
// Revoke the key
|
||||
req3 := httptest.NewRequest("DELETE", "/api/admin/keys/"+keyID, nil)
|
||||
resp3, err := s.app.Test(req3)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(200, resp3.StatusCode)
|
||||
|
||||
// List keys again - should not include the revoked key
|
||||
req4 := httptest.NewRequest("GET", "/api/admin/keys", nil)
|
||||
resp4, err := s.app.Test(req4)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(200, resp4.StatusCode)
|
||||
|
||||
var listResult2 struct {
|
||||
Keys []map[string]interface{} `json:"keys"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
err = json.NewDecoder(resp4.Body).Decode(&listResult2)
|
||||
s.Require().NoError(err)
|
||||
|
||||
found = false
|
||||
for _, key := range listResult2.Keys {
|
||||
if key["id"].(string) == keyID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
s.False(found, "revoked key should not be in the list")
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// handlePrewarmingStatus returns the status of the pre-warming worker
|
||||
func (a *App) handlePrewarmingStatus(c *fiber.Ctx) error {
|
||||
c.Set("Content-Type", "application/json")
|
||||
|
||||
status := a.prewarmWorker.GetStatus()
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(status)
|
||||
}
|
||||
|
||||
// handlePrewarmingTrigger manually triggers a pre-warming cycle
|
||||
func (a *App) handlePrewarmingTrigger(c *fiber.Ctx) error {
|
||||
c.Set("Content-Type", "application/json")
|
||||
|
||||
ctx := c.Context()
|
||||
a.prewarmWorker.TriggerPrewarm(ctx)
|
||||
|
||||
log.Info().Msg("Pre-warming manually triggered via API")
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(fiber.Map{
|
||||
"message": "Pre-warming cycle triggered successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// PrewarmPackageRequest represents a request to pre-warm a specific package
|
||||
type PrewarmPackageRequest struct {
|
||||
Registry string `json:"registry"`
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// handlePrewarmingPackage pre-warms a specific package
|
||||
func (a *App) handlePrewarmingPackage(c *fiber.Ctx) error {
|
||||
c.Set("Content-Type", "application/json")
|
||||
|
||||
var req PrewarmPackageRequest
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
"error": "invalid JSON in request body",
|
||||
})
|
||||
}
|
||||
|
||||
// Validate request
|
||||
if req.Registry == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
"error": "registry is required",
|
||||
})
|
||||
}
|
||||
if req.Name == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
"error": "name is required",
|
||||
})
|
||||
}
|
||||
if req.Version == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
"error": "version is required",
|
||||
})
|
||||
}
|
||||
|
||||
ctx := c.Context()
|
||||
err := a.prewarmWorker.PrewarmPackage(ctx, req.Registry, req.Name, req.Version)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("registry", req.Registry).
|
||||
Str("name", req.Name).
|
||||
Str("version", req.Version).
|
||||
Msg("Failed to pre-warm package")
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
"error": "failed to pre-warm package",
|
||||
})
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("registry", req.Registry).
|
||||
Str("name", req.Name).
|
||||
Str("version", req.Version).
|
||||
Msg("Package pre-warmed via API")
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(fiber.Map{
|
||||
"message": "Package pre-warmed successfully",
|
||||
"package": fiber.Map{
|
||||
"registry": req.Registry,
|
||||
"name": req.Name,
|
||||
"version": req.Version,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,162 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/prewarming"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type PrewarmingHandlersTestSuite struct {
|
||||
suite.Suite
|
||||
app *fiber.App
|
||||
appInst *App
|
||||
prewarmWorker *prewarming.Worker
|
||||
}
|
||||
|
||||
func (s *PrewarmingHandlersTestSuite) SetupTest() {
|
||||
// Create pre-warming worker (disabled by default)
|
||||
s.prewarmWorker = prewarming.NewWorker(prewarming.Config{
|
||||
Enabled: false,
|
||||
MaxConcurrent: 5,
|
||||
})
|
||||
|
||||
// Create app instance
|
||||
s.appInst = &App{
|
||||
prewarmWorker: s.prewarmWorker,
|
||||
}
|
||||
|
||||
// Create Fiber app
|
||||
s.app = fiber.New()
|
||||
|
||||
// Register routes
|
||||
s.app.Get("/api/admin/prewarming/status", s.appInst.handlePrewarmingStatus)
|
||||
s.app.Post("/api/admin/prewarming/trigger", s.appInst.handlePrewarmingTrigger)
|
||||
s.app.Post("/api/admin/prewarming/package", s.appInst.handlePrewarmingPackage)
|
||||
}
|
||||
|
||||
func (s *PrewarmingHandlersTestSuite) TearDownTest() {
|
||||
if s.prewarmWorker != nil {
|
||||
s.prewarmWorker.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrewarmingHandlersTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(PrewarmingHandlersTestSuite))
|
||||
}
|
||||
|
||||
func (s *PrewarmingHandlersTestSuite) TestHandlePrewarmingStatus() {
|
||||
req := httptest.NewRequest("GET", "/api/admin/prewarming/status", nil)
|
||||
resp, err := s.app.Test(req)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(200, resp.StatusCode)
|
||||
|
||||
var result struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Running bool `json:"running"`
|
||||
QueueSize int `json:"queue_size"`
|
||||
ActiveWorkers int `json:"active_workers"`
|
||||
}
|
||||
err = json.NewDecoder(resp.Body).Decode(&result)
|
||||
s.NoError(err)
|
||||
s.False(result.Enabled) // Disabled in test setup
|
||||
}
|
||||
|
||||
func (s *PrewarmingHandlersTestSuite) TestHandlePrewarmingTrigger() {
|
||||
req := httptest.NewRequest("POST", "/api/admin/prewarming/trigger", nil)
|
||||
resp, err := s.app.Test(req)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(200, resp.StatusCode)
|
||||
|
||||
var result map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&result)
|
||||
s.NoError(err)
|
||||
s.Contains(result, "message")
|
||||
}
|
||||
|
||||
func (s *PrewarmingHandlersTestSuite) TestHandlePrewarmingPackage() {
|
||||
tests := []struct {
|
||||
requestBody map[string]string
|
||||
name string
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "prewarm npm package",
|
||||
requestBody: map[string]string{
|
||||
"registry": "npm",
|
||||
"name": "lodash",
|
||||
"version": "4.17.21",
|
||||
},
|
||||
expectedStatus: 200,
|
||||
},
|
||||
{
|
||||
name: "prewarm pypi package",
|
||||
requestBody: map[string]string{
|
||||
"registry": "pypi",
|
||||
"name": "requests",
|
||||
"version": "2.28.0",
|
||||
},
|
||||
expectedStatus: 200,
|
||||
},
|
||||
{
|
||||
name: "prewarm go package",
|
||||
requestBody: map[string]string{
|
||||
"registry": "go",
|
||||
"name": "github.com/stretchr/testify",
|
||||
"version": "v1.8.0",
|
||||
},
|
||||
expectedStatus: 200,
|
||||
},
|
||||
{
|
||||
name: "missing registry",
|
||||
requestBody: map[string]string{
|
||||
"name": "lodash",
|
||||
"version": "4.17.21",
|
||||
},
|
||||
expectedStatus: 400,
|
||||
},
|
||||
{
|
||||
name: "missing name",
|
||||
requestBody: map[string]string{
|
||||
"registry": "npm",
|
||||
"version": "4.17.21",
|
||||
},
|
||||
expectedStatus: 400,
|
||||
},
|
||||
{
|
||||
name: "missing version",
|
||||
requestBody: map[string]string{
|
||||
"registry": "npm",
|
||||
"name": "lodash",
|
||||
},
|
||||
expectedStatus: 400,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
bodyBytes, err := json.Marshal(tt.requestBody)
|
||||
s.Require().NoError(err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/admin/prewarming/package", bytes.NewReader(bodyBytes))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.app.Test(req)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(tt.expectedStatus, resp.StatusCode)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PrewarmingHandlersTestSuite) TestHandlePrewarmingPackageInvalidJSON() {
|
||||
req := httptest.NewRequest("POST", "/api/admin/prewarming/package", bytes.NewReader([]byte("invalid json")))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.app.Test(req)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(400, resp.StatusCode)
|
||||
}
|
||||
@@ -7,9 +7,9 @@ import (
|
||||
|
||||
// ValidationResult represents a cached credential validation result
|
||||
type ValidationResult struct {
|
||||
Allowed bool
|
||||
ExpiresAt time.Time
|
||||
Reason string
|
||||
Allowed bool
|
||||
}
|
||||
|
||||
// ValidationCache caches credential validation results to reduce upstream checks
|
||||
|
||||
Vendored
+47
-14
@@ -11,6 +11,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/analytics"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/errors"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metrics"
|
||||
@@ -27,15 +28,21 @@ type ScannerInterface interface {
|
||||
CheckVulnerabilities(ctx context.Context, registry, packageName, version string) (blocked bool, reason string, err error)
|
||||
}
|
||||
|
||||
// AnalyticsInterface defines the interface for analytics tracking
|
||||
type AnalyticsInterface interface {
|
||||
TrackDownload(download analytics.PackageDownload)
|
||||
}
|
||||
|
||||
// Manager coordinates caching operations between storage and metadata
|
||||
type Manager struct {
|
||||
storage storage.StorageBackend
|
||||
metadata metadata.MetadataStore
|
||||
scanner ScannerInterface
|
||||
config Config
|
||||
sf singleflight.Group
|
||||
mu sync.RWMutex
|
||||
evicting bool
|
||||
storage storage.StorageBackend
|
||||
metadata metadata.MetadataStore
|
||||
scanner ScannerInterface
|
||||
analytics AnalyticsInterface
|
||||
sf singleflight.Group
|
||||
config Config
|
||||
mu sync.RWMutex
|
||||
evicting bool
|
||||
}
|
||||
|
||||
// Config holds cache manager configuration
|
||||
@@ -48,15 +55,15 @@ type Config struct {
|
||||
|
||||
// CacheEntry represents a cached package
|
||||
type CacheEntry struct {
|
||||
Package *metadata.Package
|
||||
Data io.ReadCloser
|
||||
FromCache bool
|
||||
Package *metadata.Package
|
||||
UpstreamURL string
|
||||
CacheControl string
|
||||
FromCache bool
|
||||
}
|
||||
|
||||
// New creates a new cache manager
|
||||
func New(storage storage.StorageBackend, metadata metadata.MetadataStore, scanner ScannerInterface, config Config) (*Manager, error) {
|
||||
func New(storage storage.StorageBackend, metadata metadata.MetadataStore, scanner ScannerInterface, analytics AnalyticsInterface, config Config) (*Manager, error) {
|
||||
if storage == nil {
|
||||
return nil, errors.New(errors.ErrCodeInvalidConfig, "storage backend is required")
|
||||
}
|
||||
@@ -70,6 +77,11 @@ func New(storage storage.StorageBackend, metadata metadata.MetadataStore, scanne
|
||||
log.Info().Msg("Cache manager initialized with security scanning enabled")
|
||||
}
|
||||
|
||||
// Analytics is optional - can be nil if analytics tracking is disabled
|
||||
if analytics != nil {
|
||||
log.Info().Msg("Cache manager initialized with analytics tracking enabled")
|
||||
}
|
||||
|
||||
if config.DefaultTTL == 0 {
|
||||
config.DefaultTTL = 7 * 24 * time.Hour // 7 days default
|
||||
}
|
||||
@@ -87,10 +99,11 @@ func New(storage storage.StorageBackend, metadata metadata.MetadataStore, scanne
|
||||
}
|
||||
|
||||
manager := &Manager{
|
||||
storage: storage,
|
||||
metadata: metadata,
|
||||
scanner: scanner,
|
||||
config: config,
|
||||
storage: storage,
|
||||
metadata: metadata,
|
||||
scanner: scanner,
|
||||
analytics: analytics,
|
||||
config: config,
|
||||
}
|
||||
|
||||
// Start background cleanup worker
|
||||
@@ -134,6 +147,11 @@ func (m *Manager) getOrFetch(ctx context.Context, registry, name, version string
|
||||
metrics.RecordCacheHit(registry)
|
||||
_ = m.metadata.UpdateDownloadCount(ctx, registry, name, version) // #nosec G104 -- Async update, error logged
|
||||
|
||||
// Track download in analytics if enabled
|
||||
if m.analytics != nil {
|
||||
m.trackDownload(registry, name, version, pkg.Size)
|
||||
}
|
||||
|
||||
// Check for vulnerabilities if scanner is enabled
|
||||
if m.scanner != nil {
|
||||
blocked, reason, err := m.scanner.CheckVulnerabilities(ctx, registry, name, version)
|
||||
@@ -552,6 +570,21 @@ func (m *Manager) Health(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// trackDownload tracks a package download event in analytics
|
||||
func (m *Manager) trackDownload(registry, name, version string, size int64) {
|
||||
download := analytics.PackageDownload{
|
||||
Registry: registry,
|
||||
Name: name,
|
||||
Version: version,
|
||||
Timestamp: time.Now(),
|
||||
BytesSize: size,
|
||||
ClientIP: "", // TODO: Extract from context if available
|
||||
UserAgent: "", // TODO: Extract from context if available
|
||||
}
|
||||
|
||||
m.analytics.TrackDownload(download)
|
||||
}
|
||||
|
||||
// Close closes the cache manager
|
||||
func (m *Manager) Close() error {
|
||||
var err error
|
||||
|
||||
Vendored
+22
-22
@@ -197,12 +197,12 @@ func (m *MockMetadataStore) AggregateDownloadData(ctx context.Context) error {
|
||||
// TestNew tests cache manager creation
|
||||
func TestNew(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
storage storage.StorageBackend
|
||||
metadata metadata.MetadataStore
|
||||
name string
|
||||
errContains string
|
||||
config Config
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
// GOOD: Valid configuration
|
||||
{
|
||||
@@ -262,7 +262,7 @@ func TestNew(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager, err := New(tt.storage, tt.metadata, nil, tt.config)
|
||||
manager, err := New(tt.storage, tt.metadata, nil, nil, tt.config)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
@@ -295,15 +295,15 @@ func TestNew(t *testing.T) {
|
||||
// TestGet tests cache retrieval with various scenarios
|
||||
func TestGet(t *testing.T) {
|
||||
tests := []struct {
|
||||
setupMock func(*MockStorageBackend, *MockMetadataStore)
|
||||
fetchFunc func(context.Context) (io.ReadCloser, string, error)
|
||||
name string
|
||||
registry string
|
||||
packageName string
|
||||
version string
|
||||
setupMock func(*MockStorageBackend, *MockMetadataStore)
|
||||
fetchFunc func(context.Context) (io.ReadCloser, string, error)
|
||||
errContains string
|
||||
wantFromCache bool
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
// GOOD: Cache hit
|
||||
{
|
||||
@@ -489,7 +489,7 @@ func TestGet(t *testing.T) {
|
||||
tt.setupMock(mockStorage, mockMetadata)
|
||||
}
|
||||
|
||||
manager, err := New(mockStorage, mockMetadata, nil, Config{
|
||||
manager, err := New(mockStorage, mockMetadata, nil, nil, Config{
|
||||
DefaultTTL: 24 * time.Hour,
|
||||
CleanupInterval: 1 * time.Hour,
|
||||
})
|
||||
@@ -523,13 +523,13 @@ func TestGet(t *testing.T) {
|
||||
// TestDelete tests package deletion
|
||||
func TestDelete(t *testing.T) {
|
||||
tests := []struct {
|
||||
setupMock func(*MockStorageBackend, *MockMetadataStore)
|
||||
name string
|
||||
registry string
|
||||
packageName string
|
||||
version string
|
||||
setupMock func(*MockStorageBackend, *MockMetadataStore)
|
||||
wantErr bool
|
||||
errContains string
|
||||
wantErr bool
|
||||
}{
|
||||
// GOOD: Successful deletion
|
||||
{
|
||||
@@ -615,7 +615,7 @@ func TestDelete(t *testing.T) {
|
||||
tt.setupMock(mockStorage, mockMetadata)
|
||||
}
|
||||
|
||||
manager, err := New(mockStorage, mockMetadata, nil, Config{})
|
||||
manager, err := New(mockStorage, mockMetadata, nil, nil, Config{})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -639,10 +639,10 @@ func TestDelete(t *testing.T) {
|
||||
// TestHealth tests health check functionality
|
||||
func TestHealth(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupMock func(*MockStorageBackend, *MockMetadataStore)
|
||||
wantErr bool
|
||||
name string
|
||||
errContains string
|
||||
wantErr bool
|
||||
}{
|
||||
// GOOD: Both healthy
|
||||
{
|
||||
@@ -692,7 +692,7 @@ func TestHealth(t *testing.T) {
|
||||
tt.setupMock(mockStorage, mockMetadata)
|
||||
}
|
||||
|
||||
manager, err := New(mockStorage, mockMetadata, nil, Config{})
|
||||
manager, err := New(mockStorage, mockMetadata, nil, nil, Config{})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -727,7 +727,7 @@ func TestGetStats(t *testing.T) {
|
||||
|
||||
mockMetadata.On("GetStats", mock.Anything, "npm").Return(expectedStats, nil)
|
||||
|
||||
manager, err := New(mockStorage, mockMetadata, nil, Config{})
|
||||
manager, err := New(mockStorage, mockMetadata, nil, nil, Config{})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -741,8 +741,8 @@ func TestGetStats(t *testing.T) {
|
||||
// TestClose tests manager cleanup
|
||||
func TestClose(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupMock func(*MockStorageBackend, *MockMetadataStore)
|
||||
name string
|
||||
wantErr bool
|
||||
}{
|
||||
// GOOD: Clean close
|
||||
@@ -792,7 +792,7 @@ func TestClose(t *testing.T) {
|
||||
tt.setupMock(mockStorage, mockMetadata)
|
||||
}
|
||||
|
||||
manager, err := New(mockStorage, mockMetadata, nil, Config{})
|
||||
manager, err := New(mockStorage, mockMetadata, nil, nil, Config{})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = manager.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
@@ -812,11 +812,11 @@ func TestClose(t *testing.T) {
|
||||
// TestEvict tests LRU eviction
|
||||
func TestEvict(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
needed int64
|
||||
setupMock func(*MockStorageBackend, *MockMetadataStore)
|
||||
wantErr bool
|
||||
name string
|
||||
errContains string
|
||||
needed int64
|
||||
wantErr bool
|
||||
}{
|
||||
// GOOD: Successful eviction
|
||||
{
|
||||
@@ -881,7 +881,7 @@ func TestEvict(t *testing.T) {
|
||||
tt.setupMock(mockStorage, mockMetadata)
|
||||
}
|
||||
|
||||
manager, err := New(mockStorage, mockMetadata, nil, Config{})
|
||||
manager, err := New(mockStorage, mockMetadata, nil, nil, Config{})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -907,7 +907,7 @@ func TestGenerateStorageKey(t *testing.T) {
|
||||
mockStorage := &MockStorageBackend{}
|
||||
mockMetadata := &MockMetadataStore{}
|
||||
|
||||
manager, err := New(mockStorage, mockMetadata, nil, Config{})
|
||||
manager, err := New(mockStorage, mockMetadata, nil, nil, Config{})
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
@@ -954,7 +954,7 @@ func TestConcurrentGet(t *testing.T) {
|
||||
io.NopCloser(bytes.NewReader([]byte("test data"))), nil).Maybe()
|
||||
mockMetadata.On("UpdateDownloadCount", mock.Anything, "npm", "concurrent", "1.0.0").Return(nil).Maybe()
|
||||
|
||||
manager, err := New(mockStorage, mockMetadata, nil, Config{})
|
||||
manager, err := New(mockStorage, mockMetadata, nil, nil, Config{})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
+1
-128
@@ -4,10 +4,7 @@ import (
|
||||
"crypto/md5" // #nosec G501 -- MD5 used for ETag generation, not cryptographic security
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
@@ -214,38 +211,11 @@ func (m *Middleware) generateETag(body []byte) string {
|
||||
return `"` + hex.EncodeToString(hash[:]) + `"`
|
||||
}
|
||||
|
||||
// SetLastModified sets the Last-Modified header
|
||||
func SetLastModified(w http.ResponseWriter, t time.Time) {
|
||||
w.Header().Set("Last-Modified", t.UTC().Format(http.TimeFormat))
|
||||
}
|
||||
|
||||
// SetCacheControl sets a custom Cache-Control header
|
||||
func SetCacheControl(w http.ResponseWriter, cc CacheControl) {
|
||||
w.Header().Set("Cache-Control", cc.String())
|
||||
}
|
||||
|
||||
// SetNoCache sets headers to prevent caching
|
||||
func SetNoCache(w http.ResponseWriter) {
|
||||
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
|
||||
w.Header().Set("Pragma", "no-cache")
|
||||
w.Header().Set("Expires", "0")
|
||||
}
|
||||
|
||||
// SetImmutable sets headers for immutable content (content-addressed files)
|
||||
func SetImmutable(w http.ResponseWriter, maxAge int) {
|
||||
cc := CacheControl{
|
||||
Public: true,
|
||||
MaxAge: maxAge,
|
||||
Immutable: true,
|
||||
}
|
||||
w.Header().Set("Cache-Control", cc.String())
|
||||
}
|
||||
|
||||
// responseWriter wraps http.ResponseWriter to capture response
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
body []byte
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func (rw *responseWriter) WriteHeader(statusCode int) {
|
||||
@@ -261,100 +231,3 @@ func (rw *responseWriter) Write(b []byte) (int, error) {
|
||||
rw.body = append(rw.body, b...)
|
||||
return rw.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
// HandleRange handles HTTP Range requests for partial content
|
||||
func HandleRange(w http.ResponseWriter, r *http.Request, content io.ReadSeeker, size int64, modTime time.Time) error {
|
||||
// Set Last-Modified header
|
||||
SetLastModified(w, modTime)
|
||||
|
||||
// Check for Range header
|
||||
rangeHeader := r.Header.Get("Range")
|
||||
if rangeHeader == "" {
|
||||
// No range request - serve full content
|
||||
w.Header().Set("Content-Length", strconv.FormatInt(size, 10))
|
||||
w.Header().Set("Accept-Ranges", "bytes")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, err := io.Copy(w, content)
|
||||
return err
|
||||
}
|
||||
|
||||
// Parse range header (simplified - only handles single range)
|
||||
// Format: bytes=start-end
|
||||
var start, end int64
|
||||
n, err := fmt.Sscanf(rangeHeader, "bytes=%d-%d", &start, &end)
|
||||
if err != nil || n != 2 {
|
||||
// Invalid range - serve full content
|
||||
w.Header().Set("Content-Length", strconv.FormatInt(size, 10))
|
||||
w.Header().Set("Accept-Ranges", "bytes")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, err := io.Copy(w, content)
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate range
|
||||
if start < 0 || start >= size || end < start || end >= size {
|
||||
w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", size))
|
||||
w.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Seek to start position
|
||||
if _, err := content.Seek(start, io.SeekStart); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Calculate content length
|
||||
contentLength := end - start + 1
|
||||
|
||||
// Set headers for partial content
|
||||
w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, size))
|
||||
w.Header().Set("Content-Length", strconv.FormatInt(contentLength, 10))
|
||||
w.Header().Set("Accept-Ranges", "bytes")
|
||||
w.WriteHeader(http.StatusPartialContent)
|
||||
|
||||
// Copy range to response
|
||||
_, err = io.CopyN(w, content, contentLength)
|
||||
return err
|
||||
}
|
||||
|
||||
// DefaultCacheControl returns sensible defaults for different content types
|
||||
func DefaultCacheControl(contentType string, versioned bool) CacheControl {
|
||||
if versioned {
|
||||
// Content-addressed or versioned resources can be cached forever
|
||||
return CacheControl{
|
||||
Public: true,
|
||||
MaxAge: 31536000, // 1 year
|
||||
Immutable: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Default caching based on content type
|
||||
switch contentType {
|
||||
case "application/json":
|
||||
return CacheControl{
|
||||
Public: true,
|
||||
MaxAge: 3600, // 1 hour
|
||||
SMaxAge: 7200, // 2 hours for shared caches
|
||||
}
|
||||
case "application/octet-stream", "application/x-gzip", "application/zip":
|
||||
// Binary packages
|
||||
return CacheControl{
|
||||
Public: true,
|
||||
MaxAge: 86400, // 1 day
|
||||
SMaxAge: 604800, // 1 week for shared caches
|
||||
}
|
||||
case "text/html":
|
||||
// HTML should revalidate
|
||||
return CacheControl{
|
||||
Public: true,
|
||||
MaxAge: 0,
|
||||
MustRevalidate: true,
|
||||
}
|
||||
default:
|
||||
return CacheControl{
|
||||
Public: true,
|
||||
MaxAge: 3600, // 1 hour default
|
||||
SMaxAge: 7200,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,299 @@
|
||||
package cdn
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type CDNMiddlewareTestSuite struct {
|
||||
suite.Suite
|
||||
middleware *Middleware
|
||||
}
|
||||
|
||||
func (s *CDNMiddlewareTestSuite) SetupTest() {
|
||||
s.middleware = NewMiddleware(Config{
|
||||
DefaultCacheControl: CacheControl{
|
||||
Public: true,
|
||||
MaxAge: 3600,
|
||||
SMaxAge: 7200,
|
||||
},
|
||||
EnableETag: true,
|
||||
EnableVary: true,
|
||||
})
|
||||
}
|
||||
|
||||
func TestCDNMiddlewareTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(CDNMiddlewareTestSuite))
|
||||
}
|
||||
|
||||
func (s *CDNMiddlewareTestSuite) TestCacheControlHeader() {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("test response"))
|
||||
})
|
||||
|
||||
wrappedHandler := s.middleware.Handler(handler)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
wrappedHandler.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(http.StatusOK, w.Code)
|
||||
s.Contains(w.Header().Get("Cache-Control"), "public")
|
||||
s.Contains(w.Header().Get("Cache-Control"), "max-age=3600")
|
||||
s.Contains(w.Header().Get("Cache-Control"), "s-maxage=7200")
|
||||
}
|
||||
|
||||
func (s *CDNMiddlewareTestSuite) TestETagGeneration() {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("test response content"))
|
||||
})
|
||||
|
||||
wrappedHandler := s.middleware.Handler(handler)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
wrappedHandler.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(http.StatusOK, w.Code)
|
||||
etag := w.Header().Get("ETag")
|
||||
s.NotEmpty(etag)
|
||||
s.True(len(etag) > 0)
|
||||
}
|
||||
|
||||
func (s *CDNMiddlewareTestSuite) TestETagConsistencyAcrossRequests() {
|
||||
responseBody := []byte("test response content")
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(responseBody)
|
||||
})
|
||||
|
||||
wrappedHandler := s.middleware.Handler(handler)
|
||||
|
||||
// First request to get ETag
|
||||
req1 := httptest.NewRequest("GET", "/test", nil)
|
||||
w1 := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(w1, req1)
|
||||
etag := w1.Header().Get("ETag")
|
||||
s.NotEmpty(etag)
|
||||
s.Equal(http.StatusOK, w1.Code)
|
||||
|
||||
// Verify ETag is consistent for same content
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(w2, req2)
|
||||
etag2 := w2.Header().Get("ETag")
|
||||
s.Equal(etag, etag2, "ETag should be consistent for same content")
|
||||
}
|
||||
|
||||
func (s *CDNMiddlewareTestSuite) TestVaryHeader() {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("test"))
|
||||
})
|
||||
|
||||
wrappedHandler := s.middleware.Handler(handler)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer token")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(w, req)
|
||||
|
||||
vary := w.Header().Get("Vary")
|
||||
s.NotEmpty(vary)
|
||||
s.Contains(vary, "Accept-Encoding")
|
||||
s.Contains(vary, "Authorization")
|
||||
s.Contains(vary, "Accept")
|
||||
}
|
||||
|
||||
func (s *CDNMiddlewareTestSuite) TestCacheControlString() {
|
||||
tests := []struct {
|
||||
name string
|
||||
expected string
|
||||
cc CacheControl
|
||||
}{
|
||||
{
|
||||
name: "public with max-age",
|
||||
cc: CacheControl{
|
||||
Public: true,
|
||||
MaxAge: 3600,
|
||||
},
|
||||
expected: "public, max-age=3600",
|
||||
},
|
||||
{
|
||||
name: "private with no-cache",
|
||||
cc: CacheControl{
|
||||
Private: true,
|
||||
NoCache: true,
|
||||
},
|
||||
expected: "private, no-cache",
|
||||
},
|
||||
{
|
||||
name: "immutable",
|
||||
cc: CacheControl{
|
||||
Public: true,
|
||||
MaxAge: 31536000,
|
||||
Immutable: true,
|
||||
},
|
||||
expected: "public, immutable, max-age=31536000",
|
||||
},
|
||||
{
|
||||
name: "no-store",
|
||||
cc: CacheControl{
|
||||
NoStore: true,
|
||||
},
|
||||
expected: "no-store",
|
||||
},
|
||||
{
|
||||
name: "must-revalidate",
|
||||
cc: CacheControl{
|
||||
Public: true,
|
||||
MustRevalidate: true,
|
||||
},
|
||||
expected: "public, must-revalidate",
|
||||
},
|
||||
{
|
||||
name: "s-maxage",
|
||||
cc: CacheControl{
|
||||
Public: true,
|
||||
MaxAge: 3600,
|
||||
SMaxAge: 7200,
|
||||
},
|
||||
expected: "public, max-age=3600, s-maxage=7200",
|
||||
},
|
||||
{
|
||||
name: "stale-while-revalidate",
|
||||
cc: CacheControl{
|
||||
Public: true,
|
||||
MaxAge: 3600,
|
||||
StaleWhileRevalidate: 86400,
|
||||
},
|
||||
expected: "public, max-age=3600, stale-while-revalidate=86400",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
result := tt.cc.String()
|
||||
// Check that all expected parts are in the result
|
||||
for _, part := range splitCacheControl(tt.expected) {
|
||||
s.Contains(result, part)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *CDNMiddlewareTestSuite) TestGenerateETag() {
|
||||
tests := []struct {
|
||||
name string
|
||||
body []byte
|
||||
expected bool // true if ETag should be generated
|
||||
}{
|
||||
{
|
||||
name: "non-empty body",
|
||||
body: []byte("test content"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "empty body",
|
||||
body: []byte{},
|
||||
expected: true, // Empty body still generates ETag (MD5 of empty string)
|
||||
},
|
||||
{
|
||||
name: "nil body",
|
||||
body: nil,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
etag := s.middleware.generateETag(tt.body)
|
||||
if tt.expected {
|
||||
s.NotEmpty(etag)
|
||||
s.True(len(etag) > 2) // Should be quoted
|
||||
} else {
|
||||
s.Empty(etag)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *CDNMiddlewareTestSuite) TestETagConsistency() {
|
||||
// Same content should produce same ETag
|
||||
body := []byte("consistent content")
|
||||
etag1 := s.middleware.generateETag(body)
|
||||
etag2 := s.middleware.generateETag(body)
|
||||
|
||||
s.Equal(etag1, etag2)
|
||||
|
||||
// Different content should produce different ETag
|
||||
body2 := []byte("different content")
|
||||
etag3 := s.middleware.generateETag(body2)
|
||||
|
||||
s.NotEqual(etag1, etag3)
|
||||
}
|
||||
|
||||
func (s *CDNMiddlewareTestSuite) TestNoCacheFor4xxErrors() {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
w.Write([]byte("not found"))
|
||||
})
|
||||
|
||||
wrappedHandler := s.middleware.Handler(handler)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
wrappedHandler.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(http.StatusNotFound, w.Code)
|
||||
// 4xx errors should not have cache headers applied
|
||||
// (based on the middleware only applying headers for 2xx status codes)
|
||||
}
|
||||
|
||||
func (s *CDNMiddlewareTestSuite) TestNoCacheFor5xxErrors() {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("error"))
|
||||
})
|
||||
|
||||
wrappedHandler := s.middleware.Handler(handler)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
wrappedHandler.ServeHTTP(w, req)
|
||||
|
||||
s.Equal(http.StatusInternalServerError, w.Code)
|
||||
// 5xx errors should not have cache headers applied
|
||||
}
|
||||
|
||||
// Helper function to split cache-control string
|
||||
func splitCacheControl(s string) []string {
|
||||
var parts []string
|
||||
current := ""
|
||||
for _, char := range s {
|
||||
if char == ',' {
|
||||
if current != "" {
|
||||
parts = append(parts, current)
|
||||
current = ""
|
||||
}
|
||||
} else if char != ' ' {
|
||||
current += string(char)
|
||||
}
|
||||
}
|
||||
if current != "" {
|
||||
parts = append(parts, current)
|
||||
}
|
||||
return parts
|
||||
}
|
||||
+45
-43
@@ -7,42 +7,42 @@ import (
|
||||
|
||||
// Config is the main configuration struct
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server" json:"server"`
|
||||
Storage StorageConfig `mapstructure:"storage" json:"storage"`
|
||||
Metadata MetadataConfig `mapstructure:"metadata" json:"metadata"`
|
||||
Cache CacheConfig `mapstructure:"cache" json:"cache"`
|
||||
Security SecurityConfig `mapstructure:"security" json:"security"`
|
||||
Auth AuthConfig `mapstructure:"auth" json:"auth"`
|
||||
Network NetworkConfig `mapstructure:"network" json:"network"`
|
||||
Logging LoggingConfig `mapstructure:"logging" json:"logging"`
|
||||
Metadata MetadataConfig `mapstructure:"metadata" json:"metadata"`
|
||||
Handlers HandlersConfig `mapstructure:"handlers" json:"handlers"`
|
||||
Server ServerConfig `mapstructure:"server" json:"server"`
|
||||
Logging LoggingConfig `mapstructure:"logging" json:"logging"`
|
||||
Network NetworkConfig `mapstructure:"network" json:"network"`
|
||||
Auth AuthConfig `mapstructure:"auth" json:"auth"`
|
||||
}
|
||||
|
||||
// ServerConfig contains HTTP server configuration
|
||||
type ServerConfig struct {
|
||||
TLS TLSConfig `mapstructure:"tls" json:"tls"`
|
||||
Host string `mapstructure:"host" json:"host"`
|
||||
Port int `mapstructure:"port" json:"port"`
|
||||
ReadTimeout time.Duration `mapstructure:"read_timeout" json:"read_timeout"`
|
||||
WriteTimeout time.Duration `mapstructure:"write_timeout" json:"write_timeout"`
|
||||
IdleTimeout time.Duration `mapstructure:"idle_timeout" json:"idle_timeout"`
|
||||
TLS TLSConfig `mapstructure:"tls" json:"tls"`
|
||||
}
|
||||
|
||||
// TLSConfig contains TLS/HTTPS configuration
|
||||
type TLSConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
CertFile string `mapstructure:"cert_file" json:"cert_file"`
|
||||
KeyFile string `mapstructure:"key_file" json:"key_file"`
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
}
|
||||
|
||||
// StorageConfig contains storage backend configuration
|
||||
type StorageConfig struct {
|
||||
Backend string `mapstructure:"backend" json:"backend"` // filesystem, s3, smb, nfs
|
||||
Options map[string]interface{} `mapstructure:"options" json:"options"`
|
||||
SMB SMBConfig `mapstructure:"smb" json:"smb"`
|
||||
Backend string `mapstructure:"backend" json:"backend"`
|
||||
Path string `mapstructure:"path" json:"path"`
|
||||
Filesystem FilesystemConfig `mapstructure:"filesystem" json:"filesystem"`
|
||||
S3 S3Config `mapstructure:"s3" json:"s3"`
|
||||
SMB SMBConfig `mapstructure:"smb" json:"smb"`
|
||||
Options map[string]interface{} `mapstructure:"options" json:"options"`
|
||||
}
|
||||
|
||||
// FilesystemConfig contains local filesystem storage configuration
|
||||
@@ -52,12 +52,14 @@ type FilesystemConfig struct {
|
||||
|
||||
// S3Config contains S3-compatible storage configuration
|
||||
type S3Config struct {
|
||||
Endpoint string `mapstructure:"endpoint" json:"endpoint"`
|
||||
Region string `mapstructure:"region" json:"region"`
|
||||
Bucket string `mapstructure:"bucket" json:"bucket"`
|
||||
AccessKeyID string `mapstructure:"access_key_id" json:"access_key_id"`
|
||||
SecretAccessKey string `mapstructure:"secret_access_key" json:"-"` // Don't serialize secrets
|
||||
UseSSL bool `mapstructure:"use_ssl" json:"use_ssl"`
|
||||
Endpoint string `mapstructure:"endpoint" json:"endpoint"` // Optional: for MinIO, etc.
|
||||
Region string `mapstructure:"region" json:"region"` // AWS region (e.g., us-east-1)
|
||||
Bucket string `mapstructure:"bucket" json:"bucket"` // S3 bucket name
|
||||
Prefix string `mapstructure:"prefix" json:"prefix"` // Optional: key prefix
|
||||
AccessKeyID string `mapstructure:"access_key_id" json:"access_key_id"` // AWS access key
|
||||
SecretAccessKey string `mapstructure:"secret_access_key" json:"-"` // AWS secret key (not serialized)
|
||||
ForcePathStyle bool `mapstructure:"force_path_style" json:"force_path_style"` // For MinIO compatibility
|
||||
UseSSL bool `mapstructure:"use_ssl" json:"use_ssl"` // Deprecated: use endpoint with https://
|
||||
}
|
||||
|
||||
// SMBConfig contains SMB/CIFS storage configuration
|
||||
@@ -71,10 +73,10 @@ type SMBConfig struct {
|
||||
|
||||
// MetadataConfig contains metadata store configuration
|
||||
type MetadataConfig struct {
|
||||
Backend string `mapstructure:"backend" json:"backend"` // sqlite, postgresql, file
|
||||
PostgreSQL PostgreSQLConfig `mapstructure:"postgresql" json:"postgresql"`
|
||||
Backend string `mapstructure:"backend" json:"backend"`
|
||||
Connection string `mapstructure:"connection" json:"connection"`
|
||||
SQLite SQLiteConfig `mapstructure:"sqlite" json:"sqlite"`
|
||||
PostgreSQL PostgreSQLConfig `mapstructure:"postgresql" json:"postgresql"`
|
||||
}
|
||||
|
||||
// SQLiteConfig contains SQLite-specific configuration
|
||||
@@ -86,33 +88,33 @@ type SQLiteConfig struct {
|
||||
// PostgreSQLConfig contains PostgreSQL-specific configuration
|
||||
type PostgreSQLConfig struct {
|
||||
Host string `mapstructure:"host" json:"host"`
|
||||
Port int `mapstructure:"port" json:"port"`
|
||||
Database string `mapstructure:"database" json:"database"`
|
||||
User string `mapstructure:"user" json:"user"`
|
||||
Password string `mapstructure:"password" json:"-"` // Don't serialize secrets
|
||||
Password string `mapstructure:"password" json:"-"`
|
||||
SSLMode string `mapstructure:"ssl_mode" json:"ssl_mode"`
|
||||
Port int `mapstructure:"port" json:"port"`
|
||||
}
|
||||
|
||||
// CacheConfig contains cache management configuration
|
||||
type CacheConfig struct {
|
||||
TTLOverrides map[string]time.Duration `mapstructure:"ttl_overrides" json:"ttl_overrides"`
|
||||
DefaultTTL time.Duration `mapstructure:"default_ttl" json:"default_ttl"`
|
||||
CleanupInterval time.Duration `mapstructure:"cleanup_interval" json:"cleanup_interval"`
|
||||
MaxSizeBytes int64 `mapstructure:"max_size_bytes" json:"max_size_bytes"`
|
||||
PerProjectQuota int64 `mapstructure:"per_project_quota" json:"per_project_quota"`
|
||||
TTLOverrides map[string]time.Duration `mapstructure:"ttl_overrides" json:"ttl_overrides"` // Per ecosystem
|
||||
}
|
||||
|
||||
// SecurityConfig contains security scanning configuration
|
||||
type SecurityConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
ScanOnDownload bool `mapstructure:"scan_on_download" json:"scan_on_download"` // Scan packages on first download
|
||||
RescanInterval time.Duration `mapstructure:"rescan_interval" json:"rescan_interval"` // How often to re-scan (e.g., 24h, 168h for weekly)
|
||||
BlockOnSeverity string `mapstructure:"block_on_severity" json:"block_on_severity"` // none, low, medium, high, critical
|
||||
BlockThresholds VulnerabilityThresholds `mapstructure:"block_thresholds" json:"block_thresholds"` // Max vulns per severity before blocking
|
||||
UpdateDBOnStartup bool `mapstructure:"update_db_on_startup" json:"update_db_on_startup"` // Update vulnerability databases on startup
|
||||
AllowedPackages []string `mapstructure:"allowed_packages" json:"allowed_packages"` // Packages that bypass security checks (format: "registry/name@version" or "registry/name")
|
||||
IgnoredCVEs []string `mapstructure:"ignored_cves" json:"ignored_cves"` // CVE IDs to ignore globally (e.g., "CVE-2021-23337")
|
||||
Scanners ScannersConfig `mapstructure:"scanners" json:"scanners"`
|
||||
BlockOnSeverity string `mapstructure:"block_on_severity" json:"block_on_severity"`
|
||||
AllowedPackages []string `mapstructure:"allowed_packages" json:"allowed_packages"`
|
||||
IgnoredCVEs []string `mapstructure:"ignored_cves" json:"ignored_cves"`
|
||||
BlockThresholds VulnerabilityThresholds `mapstructure:"block_thresholds" json:"block_thresholds"`
|
||||
RescanInterval time.Duration `mapstructure:"rescan_interval" json:"rescan_interval"`
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
ScanOnDownload bool `mapstructure:"scan_on_download" json:"scan_on_download"`
|
||||
UpdateDBOnStartup bool `mapstructure:"update_db_on_startup" json:"update_db_on_startup"`
|
||||
}
|
||||
|
||||
// VulnerabilityThresholds defines max allowed vulnerabilities per severity
|
||||
@@ -126,36 +128,36 @@ type VulnerabilityThresholds struct {
|
||||
// ScannersConfig contains individual scanner configurations
|
||||
type ScannersConfig struct {
|
||||
Trivy TrivyConfig `mapstructure:"trivy" json:"trivy"`
|
||||
OSV OSVConfig `mapstructure:"osv" json:"osv"`
|
||||
GHSA GHSAConfig `mapstructure:"ghsa" json:"ghsa"`
|
||||
Static StaticConfig `mapstructure:"static" json:"static"`
|
||||
OSV OSVConfig `mapstructure:"osv" json:"osv"`
|
||||
Grype GrypeConfig `mapstructure:"grype" json:"grype"`
|
||||
Govulncheck GovulncheckConfig `mapstructure:"govulncheck" json:"govulncheck"`
|
||||
NpmAudit NpmAuditConfig `mapstructure:"npm_audit" json:"npm_audit"`
|
||||
PipAudit PipAuditConfig `mapstructure:"pip_audit" json:"pip_audit"`
|
||||
GHSA GHSAConfig `mapstructure:"ghsa" json:"ghsa"`
|
||||
}
|
||||
|
||||
// TrivyConfig contains Trivy scanner configuration
|
||||
type TrivyConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
Timeout time.Duration `mapstructure:"timeout" json:"timeout"`
|
||||
CacheDB string `mapstructure:"cache_db" json:"cache_db"`
|
||||
Timeout time.Duration `mapstructure:"timeout" json:"timeout"`
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
}
|
||||
|
||||
// OSVConfig contains OSV scanner configuration
|
||||
type OSVConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
APIURL string `mapstructure:"api_url" json:"api_url"`
|
||||
Timeout time.Duration `mapstructure:"timeout" json:"timeout"`
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
}
|
||||
|
||||
// StaticConfig contains static analysis configuration
|
||||
type StaticConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
AllowedLicenses []string `mapstructure:"allowed_licenses" json:"allowed_licenses"`
|
||||
MaxPackageSize int64 `mapstructure:"max_package_size" json:"max_package_size"`
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
CheckChecksums bool `mapstructure:"check_checksums" json:"check_checksums"`
|
||||
BlockSuspicious bool `mapstructure:"block_suspicious" json:"block_suspicious"`
|
||||
AllowedLicenses []string `mapstructure:"allowed_licenses" json:"allowed_licenses"`
|
||||
}
|
||||
|
||||
// GrypeConfig contains Grype scanner configuration
|
||||
@@ -184,16 +186,16 @@ type PipAuditConfig struct {
|
||||
|
||||
// GHSAConfig contains GitHub Advisory Database scanner configuration
|
||||
type GHSAConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
Token string `mapstructure:"token" json:"-"`
|
||||
Timeout time.Duration `mapstructure:"timeout" json:"timeout"`
|
||||
Token string `mapstructure:"token" json:"-"` // GitHub token for higher rate limits (don't serialize)
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
}
|
||||
|
||||
// AuthConfig contains authentication configuration
|
||||
type AuthConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
KeyExpiration time.Duration `mapstructure:"key_expiration" json:"key_expiration"`
|
||||
BcryptCost int `mapstructure:"bcrypt_cost" json:"bcrypt_cost"`
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
AuditLog bool `mapstructure:"audit_log" json:"audit_log"`
|
||||
}
|
||||
|
||||
@@ -245,24 +247,24 @@ type HandlersConfig struct {
|
||||
|
||||
// GoHandlerConfig contains Go proxy configuration
|
||||
type GoHandlerConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
UpstreamProxy string `mapstructure:"upstream_proxy" json:"upstream_proxy"`
|
||||
ChecksumDB string `mapstructure:"checksum_db" json:"checksum_db"`
|
||||
GitCredentialsFile string `mapstructure:"git_credentials_file" json:"git_credentials_file"`
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
VerifyChecksums bool `mapstructure:"verify_checksums" json:"verify_checksums"`
|
||||
GitCredentialsFile string `mapstructure:"git_credentials_file" json:"git_credentials_file"` // Path to git credentials JSON file
|
||||
}
|
||||
|
||||
// NPMHandlerConfig contains NPM registry configuration
|
||||
type NPMHandlerConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
UpstreamRegistry string `mapstructure:"upstream_registry" json:"upstream_registry"`
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
}
|
||||
|
||||
// PyPIHandlerConfig contains PyPI configuration
|
||||
type PyPIHandlerConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
UpstreamURL string `mapstructure:"upstream_url" json:"upstream_url"`
|
||||
SimpleAPIURL string `mapstructure:"simple_api_url" json:"simple_api_url"`
|
||||
Enabled bool `mapstructure:"enabled" json:"enabled"`
|
||||
}
|
||||
|
||||
// Default returns a configuration with sensible defaults
|
||||
|
||||
@@ -41,10 +41,10 @@ func (s *ConfigTestSuite) TestDefault() {
|
||||
|
||||
func (s *ConfigTestSuite) TestValidate() {
|
||||
tests := []struct {
|
||||
name string
|
||||
modify func(*Config)
|
||||
expectError bool
|
||||
name string
|
||||
errorSubstr string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid_config",
|
||||
@@ -175,11 +175,11 @@ func (s *ConfigTestSuite) TestValidate() {
|
||||
|
||||
func (s *ConfigTestSuite) TestLoad() {
|
||||
tests := []struct {
|
||||
envVars map[string]string
|
||||
validate func(*Config)
|
||||
name string
|
||||
configYAML string
|
||||
envVars map[string]string
|
||||
expectError bool
|
||||
validate func(*Config)
|
||||
}{
|
||||
{
|
||||
name: "valid_yaml_config",
|
||||
@@ -319,13 +319,6 @@ func (s *ConfigTestSuite) TestLoadMissingFile() {
|
||||
s.Nil(cfg)
|
||||
}
|
||||
|
||||
func (s *ConfigTestSuite) TestLoadWithDefaults() {
|
||||
// Invalid config path should return defaults
|
||||
cfg := LoadWithDefaults("/invalid/path/config.yaml")
|
||||
s.NotNil(cfg)
|
||||
s.Equal(8080, cfg.Server.Port)
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkDefault(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
@@ -344,8 +337,8 @@ func BenchmarkValidate(b *testing.B) {
|
||||
// Table-driven edge cases
|
||||
func TestConfigEdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *Config
|
||||
name string
|
||||
valid bool
|
||||
}{
|
||||
{
|
||||
|
||||
@@ -51,12 +51,3 @@ func Load(configPath string) (*Config, error) {
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// LoadWithDefaults loads configuration or returns defaults on error
|
||||
func LoadWithDefaults(configPath string) *Config {
|
||||
cfg, err := Load(configPath)
|
||||
if err != nil {
|
||||
return Default()
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
@@ -58,11 +58,3 @@ var HTTPStatusCode = map[string]int{
|
||||
ErrCodeServiceUnavailable: 503,
|
||||
ErrCodeCircuitOpen: 503,
|
||||
}
|
||||
|
||||
// GetHTTPStatus returns the HTTP status code for an error code
|
||||
func GetHTTPStatus(code string) int {
|
||||
if status, ok := HTTPStatusCode[code]; ok {
|
||||
return status
|
||||
}
|
||||
return 500 // Default to internal server error
|
||||
}
|
||||
|
||||
+4
-50
@@ -6,11 +6,11 @@ import (
|
||||
|
||||
// Error represents a structured error with code and details
|
||||
type Error struct {
|
||||
Details interface{} `json:"details,omitempty"`
|
||||
Cause error `json:"-"`
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Details interface{} `json:"details,omitempty"`
|
||||
Trace []string `json:"trace,omitempty"`
|
||||
Cause error `json:"-"` // Internal cause, not serialized
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
@@ -34,26 +34,12 @@ func New(code, message string) *Error {
|
||||
}
|
||||
}
|
||||
|
||||
// Newf creates a new error with formatted message
|
||||
func Newf(code, format string, args ...interface{}) *Error {
|
||||
return &Error{
|
||||
Code: code,
|
||||
Message: fmt.Sprintf(format, args...),
|
||||
}
|
||||
}
|
||||
|
||||
// WithDetails adds details to the error
|
||||
func (e *Error) WithDetails(details interface{}) *Error {
|
||||
e.Details = details
|
||||
return e
|
||||
}
|
||||
|
||||
// WithTrace adds stack trace to the error
|
||||
func (e *Error) WithTrace(trace []string) *Error {
|
||||
e.Trace = trace
|
||||
return e
|
||||
}
|
||||
|
||||
// WithCause adds an underlying cause to the error
|
||||
func (e *Error) WithCause(cause error) *Error {
|
||||
e.Cause = cause
|
||||
@@ -69,44 +55,12 @@ func Wrap(err error, code, message string) *Error {
|
||||
}
|
||||
}
|
||||
|
||||
// Wrapf wraps an existing error with formatted message
|
||||
func Wrapf(err error, code, format string, args ...interface{}) *Error {
|
||||
return &Error{
|
||||
Code: code,
|
||||
Message: fmt.Sprintf(format, args...),
|
||||
Cause: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Common error constructors
|
||||
func BadRequest(message string) *Error {
|
||||
return New(ErrCodeBadRequest, message)
|
||||
}
|
||||
|
||||
func Unauthorized(message string) *Error {
|
||||
return New(ErrCodeUnauthorized, message)
|
||||
}
|
||||
|
||||
func Forbidden(message string) *Error {
|
||||
return New(ErrCodeForbidden, message)
|
||||
}
|
||||
|
||||
// NotFound creates a not found error
|
||||
func NotFound(message string) *Error {
|
||||
return New(ErrCodeNotFound, message)
|
||||
}
|
||||
|
||||
func InternalServer(message string) *Error {
|
||||
return New(ErrCodeInternalServer, message)
|
||||
}
|
||||
|
||||
func PackageNotFound(name, version string) *Error {
|
||||
return New(ErrCodePackageNotFound, fmt.Sprintf("Package %s@%s not found", name, version)).
|
||||
WithDetails(map[string]string{
|
||||
"package": name,
|
||||
"version": version,
|
||||
})
|
||||
}
|
||||
|
||||
// QuotaExceeded creates a quota exceeded error
|
||||
func QuotaExceeded(limit int64) *Error {
|
||||
return New(ErrCodeQuotaExceeded, "Storage quota exceeded").
|
||||
WithDetails(map[string]interface{}{
|
||||
|
||||
+1
-131
@@ -4,7 +4,6 @@ import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
@@ -46,43 +45,10 @@ func (s *ErrorsTestSuite) TestNew() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ErrorsTestSuite) TestNewf() {
|
||||
tests := []struct {
|
||||
name string
|
||||
code string
|
||||
format string
|
||||
args []interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "formatted_message",
|
||||
code: ErrCodePackageNotFound,
|
||||
format: "Package %s@%s not found",
|
||||
args: []interface{}{"react", "18.2.0"},
|
||||
expected: "Package react@18.2.0 not found",
|
||||
},
|
||||
{
|
||||
name: "no_args",
|
||||
code: ErrCodeInternalServer,
|
||||
format: "Internal error",
|
||||
args: []interface{}{},
|
||||
expected: "Internal error",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
err := Newf(tt.code, tt.format, tt.args...)
|
||||
s.Equal(tt.code, err.Code)
|
||||
s.Equal(tt.expected, err.Message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ErrorsTestSuite) TestWithDetails() {
|
||||
tests := []struct {
|
||||
name string
|
||||
details interface{}
|
||||
name string
|
||||
}{
|
||||
{
|
||||
name: "map_details",
|
||||
@@ -106,12 +72,6 @@ func (s *ErrorsTestSuite) TestWithDetails() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ErrorsTestSuite) TestWithTrace() {
|
||||
trace := []string{"file1.go:10", "file2.go:20"}
|
||||
err := New(ErrCodeInternalServer, "test").WithTrace(trace)
|
||||
s.Equal(trace, err.Trace)
|
||||
}
|
||||
|
||||
func (s *ErrorsTestSuite) TestWithCause() {
|
||||
cause := errors.New("underlying error")
|
||||
err := New(ErrCodeStorageFailure, "test").WithCause(cause)
|
||||
@@ -129,15 +89,6 @@ func (s *ErrorsTestSuite) TestWrap() {
|
||||
s.True(errors.Is(wrapped, cause))
|
||||
}
|
||||
|
||||
func (s *ErrorsTestSuite) TestWrapf() {
|
||||
cause := errors.New("connection refused")
|
||||
wrapped := Wrapf(cause, ErrCodeUpstreamFailure, "failed to connect to %s", "registry.npmjs.org")
|
||||
|
||||
s.Equal(ErrCodeUpstreamFailure, wrapped.Code)
|
||||
s.Equal("failed to connect to registry.npmjs.org", wrapped.Message)
|
||||
s.Equal(cause, wrapped.Cause)
|
||||
}
|
||||
|
||||
func (s *ErrorsTestSuite) TestErrorString() {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -163,59 +114,6 @@ func (s *ErrorsTestSuite) TestErrorString() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ErrorsTestSuite) TestCommonConstructors() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func() *Error
|
||||
wantCode string
|
||||
}{
|
||||
{
|
||||
name: "bad_request",
|
||||
fn: func() *Error { return BadRequest("invalid input") },
|
||||
wantCode: ErrCodeBadRequest,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
fn: func() *Error { return Unauthorized("invalid token") },
|
||||
wantCode: ErrCodeUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "forbidden",
|
||||
fn: func() *Error { return Forbidden("access denied") },
|
||||
wantCode: ErrCodeForbidden,
|
||||
},
|
||||
{
|
||||
name: "not_found",
|
||||
fn: func() *Error { return NotFound("resource missing") },
|
||||
wantCode: ErrCodeNotFound,
|
||||
},
|
||||
{
|
||||
name: "internal_server",
|
||||
fn: func() *Error { return InternalServer("server error") },
|
||||
wantCode: ErrCodeInternalServer,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
err := tt.fn()
|
||||
s.Equal(tt.wantCode, err.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ErrorsTestSuite) TestPackageNotFound() {
|
||||
err := PackageNotFound("lodash", "4.17.21")
|
||||
s.Equal(ErrCodePackageNotFound, err.Code)
|
||||
s.Equal("Package lodash@4.17.21 not found", err.Message)
|
||||
s.NotNil(err.Details)
|
||||
|
||||
details, ok := err.Details.(map[string]string)
|
||||
s.True(ok)
|
||||
s.Equal("lodash", details["package"])
|
||||
s.Equal("4.17.21", details["version"])
|
||||
}
|
||||
|
||||
func (s *ErrorsTestSuite) TestQuotaExceeded() {
|
||||
limit := int64(1000000)
|
||||
err := QuotaExceeded(limit)
|
||||
@@ -275,31 +173,3 @@ func (s *ErrorsTestSuite) TestEdgeCases() {
|
||||
s.Equal(largeDetails, err.Details)
|
||||
})
|
||||
}
|
||||
|
||||
// Table-driven test for error codes
|
||||
func TestGetHTTPStatus(t *testing.T) {
|
||||
tests := []struct {
|
||||
code string
|
||||
expectedStatus int
|
||||
}{
|
||||
{ErrCodeBadRequest, 400},
|
||||
{ErrCodeUnauthorized, 401},
|
||||
{ErrCodeForbidden, 403},
|
||||
{ErrCodeNotFound, 404},
|
||||
{ErrCodeConflict, 409},
|
||||
{ErrCodePayloadTooLarge, 413},
|
||||
{ErrCodeChecksumMismatch, 422},
|
||||
{ErrCodeRateLimited, 429},
|
||||
{ErrCodeInternalServer, 500},
|
||||
{ErrCodeDatabaseFailure, 500},
|
||||
{ErrCodeUpstreamFailure, 502},
|
||||
{ErrCodeServiceUnavailable, 503},
|
||||
{"UNKNOWN_CODE", 500}, // Default
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.code, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expectedStatus, GetHTTPStatus(tt.code))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
)
|
||||
|
||||
// Response is the standard API response envelope
|
||||
type Response struct {
|
||||
Success bool `json:"success"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
Error *ErrorResponse `json:"error,omitempty"`
|
||||
Metadata *ResponseMeta `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// ErrorResponse contains error details
|
||||
type ErrorResponse struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Details interface{} `json:"details,omitempty"`
|
||||
Trace []string `json:"trace,omitempty"`
|
||||
}
|
||||
|
||||
// ResponseMeta contains request metadata
|
||||
type ResponseMeta struct {
|
||||
RequestID string `json:"request_id"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
Duration string `json:"duration,omitempty"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// WriteJSON writes a success response as JSON
|
||||
func WriteJSON(w http.ResponseWriter, statusCode int, data interface{}, meta *ResponseMeta) {
|
||||
response := Response{
|
||||
Success: statusCode < 400,
|
||||
Data: data,
|
||||
Metadata: meta,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
// Fallback to simple error response
|
||||
http.Error(w, `{"success":false,"error":{"code":"ENCODING_ERROR","message":"Failed to encode response"}}`, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// WriteError writes an error response as JSON
|
||||
func WriteError(w http.ResponseWriter, statusCode int, err *Error, meta *ResponseMeta) {
|
||||
errResp := &ErrorResponse{
|
||||
Code: err.Code,
|
||||
Message: err.Message,
|
||||
Details: err.Details,
|
||||
Trace: err.Trace,
|
||||
}
|
||||
|
||||
response := Response{
|
||||
Success: false,
|
||||
Error: errResp,
|
||||
Metadata: meta,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
if encErr := json.NewEncoder(w).Encode(response); encErr != nil {
|
||||
// Fallback to simple error response
|
||||
http.Error(w, `{"success":false,"error":{"code":"ENCODING_ERROR","message":"Failed to encode error response"}}`, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// WriteErrorSimple writes an error without metadata
|
||||
func WriteErrorSimple(w http.ResponseWriter, err *Error) {
|
||||
statusCode := GetHTTPStatus(err.Code)
|
||||
meta := &ResponseMeta{
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
WriteError(w, statusCode, err, meta)
|
||||
}
|
||||
|
||||
// WriteJSONSimple writes a success response without metadata
|
||||
func WriteJSONSimple(w http.ResponseWriter, statusCode int, data interface{}) {
|
||||
meta := &ResponseMeta{
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
WriteJSON(w, statusCode, data, meta)
|
||||
}
|
||||
@@ -21,25 +21,25 @@ const (
|
||||
|
||||
// Check represents a single health check
|
||||
type Check struct {
|
||||
Fn func(context.Context) (Status, string) `json:"-"`
|
||||
Name string `json:"name"`
|
||||
Status Status `json:"status"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Fn func(context.Context) (Status, string) `json:"-"`
|
||||
}
|
||||
|
||||
// Response is the health check response
|
||||
type Response struct {
|
||||
Success bool `json:"success"`
|
||||
Data *HealthData `json:"data,omitempty"`
|
||||
Metadata *Metadata `json:"metadata,omitempty"`
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
// HealthData contains health check data
|
||||
type HealthData struct {
|
||||
Components map[string]*Component `json:"components"`
|
||||
Status Status `json:"status"`
|
||||
Version string `json:"version"`
|
||||
Uptime string `json:"uptime"`
|
||||
Components map[string]*Component `json:"components"`
|
||||
}
|
||||
|
||||
// Component represents a system component
|
||||
@@ -57,8 +57,8 @@ type Metadata struct {
|
||||
|
||||
// Checker manages health checks
|
||||
type Checker struct {
|
||||
checks []*Check
|
||||
startTime time.Time
|
||||
checks []*Check
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
|
||||
@@ -1,275 +0,0 @@
|
||||
package lock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrLockNotAcquired = errors.New("lock not acquired")
|
||||
ErrLockNotHeld = errors.New("lock not held by this instance")
|
||||
ErrInvalidTTL = errors.New("invalid TTL: must be positive")
|
||||
)
|
||||
|
||||
// Lock represents a distributed lock
|
||||
type Lock struct {
|
||||
client *redis.Client
|
||||
key string
|
||||
value string
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// Manager manages distributed locks using Redis
|
||||
type Manager struct {
|
||||
client *redis.Client
|
||||
}
|
||||
|
||||
// Config holds Redis connection configuration
|
||||
type Config struct {
|
||||
Addr string
|
||||
Password string
|
||||
DB int
|
||||
}
|
||||
|
||||
// NewManager creates a new lock manager
|
||||
func NewManager(cfg Config) (*Manager, error) {
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: cfg.Addr,
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
})
|
||||
|
||||
// Test connection
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("addr", cfg.Addr).
|
||||
Int("db", cfg.DB).
|
||||
Msg("Connected to Redis for distributed locking")
|
||||
|
||||
return &Manager{
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Acquire attempts to acquire a lock with the given key and TTL
|
||||
// Returns a Lock instance if successful, or an error if the lock is already held
|
||||
func (m *Manager) Acquire(ctx context.Context, key string, ttl time.Duration) (*Lock, error) {
|
||||
if ttl <= 0 {
|
||||
return nil, ErrInvalidTTL
|
||||
}
|
||||
|
||||
// Generate unique value for this lock instance
|
||||
value, err := generateLockValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Try to acquire lock using SET NX (set if not exists)
|
||||
success, err := m.client.SetNX(ctx, key, value, ttl).Result()
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("key", key).
|
||||
Msg("Failed to acquire lock")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !success {
|
||||
log.Debug().
|
||||
Str("key", key).
|
||||
Msg("Lock already held by another instance")
|
||||
return nil, ErrLockNotAcquired
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("key", key).
|
||||
Dur("ttl", ttl).
|
||||
Msg("Lock acquired successfully")
|
||||
|
||||
return &Lock{
|
||||
client: m.client,
|
||||
key: key,
|
||||
value: value,
|
||||
ttl: ttl,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TryAcquire attempts to acquire a lock, retrying for the specified duration
|
||||
// Returns a Lock instance if successful within the timeout, or an error
|
||||
func (m *Manager) TryAcquire(ctx context.Context, key string, ttl, timeout time.Duration) (*Lock, error) {
|
||||
if ttl <= 0 {
|
||||
return nil, ErrInvalidTTL
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(timeout)
|
||||
ticker := time.NewTicker(50 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
lock, err := m.Acquire(ctx, key, ttl)
|
||||
if err == nil {
|
||||
return lock, nil
|
||||
}
|
||||
|
||||
if err != ErrLockNotAcquired {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-ticker.C:
|
||||
if time.Now().After(deadline) {
|
||||
return nil, ErrLockNotAcquired
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Release releases the lock
|
||||
// Returns an error if the lock is not held by this instance
|
||||
func (l *Lock) Release(ctx context.Context) error {
|
||||
// Use Lua script to ensure atomic check-and-delete
|
||||
// Only delete if the value matches (ensures we own the lock)
|
||||
script := redis.NewScript(`
|
||||
if redis.call("get", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("del", KEYS[1])
|
||||
else
|
||||
return 0
|
||||
end
|
||||
`)
|
||||
|
||||
result, err := script.Run(ctx, l.client, []string{l.key}, l.value).Result()
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("key", l.key).
|
||||
Msg("Failed to release lock")
|
||||
return err
|
||||
}
|
||||
|
||||
// Result of 0 means the lock was not deleted (not owned by us)
|
||||
if result.(int64) == 0 {
|
||||
log.Warn().
|
||||
Str("key", l.key).
|
||||
Msg("Attempted to release lock not held by this instance")
|
||||
return ErrLockNotHeld
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("key", l.key).
|
||||
Msg("Lock released successfully")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extend extends the lock TTL
|
||||
// Returns an error if the lock is not held by this instance
|
||||
func (l *Lock) Extend(ctx context.Context, additionalTTL time.Duration) error {
|
||||
// Use Lua script to ensure atomic check-and-extend
|
||||
script := redis.NewScript(`
|
||||
if redis.call("get", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("expire", KEYS[1], ARGV[2])
|
||||
else
|
||||
return 0
|
||||
end
|
||||
`)
|
||||
|
||||
newTTL := l.ttl + additionalTTL
|
||||
result, err := script.Run(ctx, l.client, []string{l.key}, l.value, int(newTTL.Seconds())).Result()
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("key", l.key).
|
||||
Msg("Failed to extend lock")
|
||||
return err
|
||||
}
|
||||
|
||||
if result.(int64) == 0 {
|
||||
log.Warn().
|
||||
Str("key", l.key).
|
||||
Msg("Attempted to extend lock not held by this instance")
|
||||
return ErrLockNotHeld
|
||||
}
|
||||
|
||||
l.ttl = newTTL
|
||||
log.Debug().
|
||||
Str("key", l.key).
|
||||
Dur("new_ttl", newTTL).
|
||||
Msg("Lock TTL extended")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsHeld checks if the lock is still held by this instance
|
||||
func (l *Lock) IsHeld(ctx context.Context) bool {
|
||||
value, err := l.client.Get(ctx, l.key).Result()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return value == l.value
|
||||
}
|
||||
|
||||
// Close closes the lock manager and its Redis connection
|
||||
func (m *Manager) Close() error {
|
||||
return m.client.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
}
|
||||
|
||||
// generateLockValue generates a cryptographically random lock value
|
||||
func generateLockValue() (string, error) {
|
||||
bytes := make([]byte, 16)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// WithLock executes a function while holding a distributed lock
|
||||
// The lock is automatically released when the function returns
|
||||
func (m *Manager) WithLock(ctx context.Context, key string, ttl time.Duration, fn func(context.Context) error) error {
|
||||
lock, err := m.Acquire(ctx, key, ttl)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := lock.Release(context.Background()); err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("key", key).
|
||||
Msg("Failed to release lock in defer")
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(ctx)
|
||||
}
|
||||
|
||||
// WithRetryLock executes a function while holding a distributed lock
|
||||
// It retries acquisition for the specified timeout duration
|
||||
func (m *Manager) WithRetryLock(ctx context.Context, key string, ttl, timeout time.Duration, fn func(context.Context) error) error {
|
||||
lock, err := m.TryAcquire(ctx, key, ttl, timeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := lock.Release(context.Background()); err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("key", key).
|
||||
Msg("Failed to release lock in defer")
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(ctx)
|
||||
}
|
||||
@@ -35,23 +35,3 @@ func Init(cfg Config) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns the global logger
|
||||
func Get() *zerolog.Logger {
|
||||
return &log.Logger
|
||||
}
|
||||
|
||||
// WithFields returns a logger with additional fields
|
||||
func WithFields(fields map[string]interface{}) *zerolog.Logger {
|
||||
logger := log.Logger
|
||||
for k, v := range fields {
|
||||
logger = logger.With().Interface(k, v).Logger()
|
||||
}
|
||||
return &logger
|
||||
}
|
||||
|
||||
// WithRequestID returns a logger with request ID
|
||||
func WithRequestID(requestID string) *zerolog.Logger {
|
||||
logger := log.With().Str("request_id", requestID).Logger()
|
||||
return &logger
|
||||
}
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// responseWriter wraps http.ResponseWriter to capture status code
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
written int64
|
||||
}
|
||||
|
||||
func (rw *responseWriter) WriteHeader(code int) {
|
||||
rw.statusCode = code
|
||||
rw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (rw *responseWriter) Write(b []byte) (int, error) {
|
||||
n, err := rw.ResponseWriter.Write(b)
|
||||
rw.written += int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Middleware is HTTP middleware for request logging
|
||||
func Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
// Generate request ID
|
||||
requestID := r.Header.Get("X-Request-ID")
|
||||
if requestID == "" {
|
||||
requestID = "req_" + uuid.New().String()[:8]
|
||||
}
|
||||
|
||||
// Wrap response writer
|
||||
rw := &responseWriter{
|
||||
ResponseWriter: w,
|
||||
statusCode: http.StatusOK,
|
||||
}
|
||||
|
||||
// Set request ID in response header
|
||||
rw.Header().Set("X-Request-ID", requestID)
|
||||
|
||||
// Call next handler
|
||||
next.ServeHTTP(rw, r)
|
||||
|
||||
// Log request
|
||||
duration := time.Since(start)
|
||||
log.Info().
|
||||
Str("request_id", requestID).
|
||||
Str("method", r.Method).
|
||||
Str("path", r.URL.Path).
|
||||
Str("remote_addr", r.RemoteAddr).
|
||||
Str("user_agent", r.UserAgent()).
|
||||
Int("status", rw.statusCode).
|
||||
Int64("bytes", rw.written).
|
||||
Dur("duration_ms", duration).
|
||||
Msg("HTTP request")
|
||||
})
|
||||
}
|
||||
+32
-32
@@ -68,37 +68,37 @@ type MetadataStore interface {
|
||||
|
||||
// Package represents package metadata
|
||||
type Package struct {
|
||||
CachedAt time.Time `json:"cached_at"`
|
||||
LastAccessed time.Time `json:"last_accessed"`
|
||||
Metadata map[string]string `json:"metadata"`
|
||||
ExpiresAt *time.Time `json:"expires_at"`
|
||||
UpstreamURL string `json:"upstream_url"`
|
||||
ChecksumMD5 string `json:"checksum_md5"`
|
||||
ChecksumSHA256 string `json:"checksum_sha256"`
|
||||
ID string `json:"id"`
|
||||
Registry string `json:"registry"` // npm, pypi, go
|
||||
Name string `json:"name"` // Package name
|
||||
Version string `json:"version"` // Package version
|
||||
StorageKey string `json:"storage_key"` // Key in storage backend
|
||||
Size int64 `json:"size"` // Package size in bytes
|
||||
ChecksumMD5 string `json:"checksum_md5"` // MD5 checksum
|
||||
ChecksumSHA256 string `json:"checksum_sha256"` // SHA256 checksum
|
||||
UpstreamURL string `json:"upstream_url"` // Original upstream URL
|
||||
CachedAt time.Time `json:"cached_at"` // When cached
|
||||
LastAccessed time.Time `json:"last_accessed"` // Last access time
|
||||
ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = never)
|
||||
DownloadCount int64 `json:"download_count"` // Download counter
|
||||
Metadata map[string]string `json:"metadata"` // Additional metadata
|
||||
SecurityScanned bool `json:"security_scanned"` // Has been scanned
|
||||
RequiresAuth bool `json:"requires_auth"` // Package requires authentication
|
||||
AuthProvider string `json:"auth_provider"` // Auth provider (github.com, npm.pkg.github.com, etc.)
|
||||
StorageKey string `json:"storage_key"`
|
||||
Version string `json:"version"`
|
||||
Name string `json:"name"`
|
||||
Registry string `json:"registry"`
|
||||
AuthProvider string `json:"auth_provider"`
|
||||
Size int64 `json:"size"`
|
||||
DownloadCount int64 `json:"download_count"`
|
||||
SecurityScanned bool `json:"security_scanned"`
|
||||
RequiresAuth bool `json:"requires_auth"`
|
||||
}
|
||||
|
||||
// ScanResult represents a security scan result
|
||||
type ScanResult struct {
|
||||
ScannedAt time.Time `json:"scanned_at"`
|
||||
Details map[string]interface{} `json:"details"`
|
||||
ID string `json:"id"`
|
||||
Registry string `json:"registry"`
|
||||
PackageName string `json:"package_name"`
|
||||
PackageVersion string `json:"package_version"`
|
||||
Scanner string `json:"scanner"` // trivy, osv, etc.
|
||||
ScannedAt time.Time `json:"scanned_at"`
|
||||
Status ScanStatus `json:"status"` // clean, vulnerable, error
|
||||
VulnerabilityCount int `json:"vulnerability_count"`
|
||||
Scanner string `json:"scanner"`
|
||||
Status ScanStatus `json:"status"`
|
||||
Vulnerabilities []Vulnerability `json:"vulnerabilities"`
|
||||
Details map[string]interface{} `json:"details"` // Scanner-specific details
|
||||
VulnerabilityCount int `json:"vulnerability_count"`
|
||||
}
|
||||
|
||||
// Vulnerability represents a security vulnerability
|
||||
@@ -143,13 +143,13 @@ const (
|
||||
|
||||
// Stats represents metadata statistics
|
||||
type Stats struct {
|
||||
LastUpdated time.Time `json:"last_updated"`
|
||||
Registry string `json:"registry"`
|
||||
TotalPackages int64 `json:"total_packages"`
|
||||
TotalSize int64 `json:"total_size"`
|
||||
TotalDownloads int64 `json:"total_downloads"`
|
||||
ScannedPackages int64 `json:"scanned_packages"`
|
||||
VulnerablePackages int64 `json:"vulnerable_packages"`
|
||||
LastUpdated time.Time `json:"last_updated"`
|
||||
}
|
||||
|
||||
// TimeSeriesDataPoint represents a single data point in time-series
|
||||
@@ -198,14 +198,14 @@ type BypassListOptions struct {
|
||||
|
||||
// ListOptions contains options for listing packages
|
||||
type ListOptions struct {
|
||||
Registry string // Filter by registry
|
||||
NamePrefix string // Filter by name prefix
|
||||
MinSize int64 // Minimum package size
|
||||
MaxSize int64 // Maximum package size
|
||||
ScannedOnly bool // Only scanned packages
|
||||
SinceDate time.Time // Packages cached since date
|
||||
Limit int // Max results
|
||||
Offset int // Pagination offset
|
||||
SortBy string // Sort field (name, size, cached_at, download_count)
|
||||
SortDesc bool // Sort descending
|
||||
SinceDate time.Time
|
||||
Registry string
|
||||
NamePrefix string
|
||||
SortBy string
|
||||
MinSize int64
|
||||
MaxSize int64
|
||||
Limit int
|
||||
Offset int
|
||||
ScannedOnly bool
|
||||
SortDesc bool
|
||||
}
|
||||
|
||||
@@ -137,21 +137,11 @@ func RecordCacheMiss(handler string) {
|
||||
CacheRequests.WithLabelValues("miss", handler).Inc()
|
||||
}
|
||||
|
||||
// RecordCacheError records a cache error
|
||||
func RecordCacheError(handler string) {
|
||||
CacheRequests.WithLabelValues("error", handler).Inc()
|
||||
}
|
||||
|
||||
// UpdateCacheSize updates the cache size metric
|
||||
func UpdateCacheSize(backend string, bytes int64) {
|
||||
CacheSizeBytes.WithLabelValues(backend).Set(float64(bytes))
|
||||
}
|
||||
|
||||
// UpdateCacheItems updates the cache items metric
|
||||
func UpdateCacheItems(handler string, count int64) {
|
||||
CacheItemsTotal.WithLabelValues(handler).Set(float64(count))
|
||||
}
|
||||
|
||||
// RecordCacheEviction records a cache eviction
|
||||
func RecordCacheEviction(reason string) {
|
||||
CacheEvictions.WithLabelValues(reason).Inc()
|
||||
@@ -162,26 +152,11 @@ func RecordStorageOperation(backend, operation, status string) {
|
||||
StorageOperations.WithLabelValues(backend, operation, status).Inc()
|
||||
}
|
||||
|
||||
// UpdateStorageQuota updates the storage quota metric
|
||||
func UpdateStorageQuota(project string, bytes int64) {
|
||||
StorageQuotaBytes.WithLabelValues(project).Set(float64(bytes))
|
||||
}
|
||||
|
||||
// RecordUpstreamRequest records an upstream request
|
||||
func RecordUpstreamRequest(registry, status string) {
|
||||
UpstreamRequests.WithLabelValues(registry, status).Inc()
|
||||
}
|
||||
|
||||
// RecordSecurityScan records a security scan
|
||||
func RecordSecurityScan(scanner, result string) {
|
||||
SecurityScans.WithLabelValues(scanner, result).Inc()
|
||||
}
|
||||
|
||||
// RecordVulnerability records a vulnerability finding
|
||||
func RecordVulnerability(severity string) {
|
||||
VulnerabilitiesFound.WithLabelValues(severity).Inc()
|
||||
}
|
||||
|
||||
// UpdateCircuitBreakerState updates the circuit breaker state
|
||||
func UpdateCircuitBreakerState(name string, state int) {
|
||||
CircuitBreakerState.WithLabelValues(name).Set(float64(state))
|
||||
|
||||
@@ -24,23 +24,23 @@ type Client struct {
|
||||
|
||||
// Config holds client configuration
|
||||
type Config struct {
|
||||
Timeout time.Duration // Request timeout
|
||||
MaxRetries int // Max retry attempts
|
||||
RetryDelay time.Duration // Initial retry delay
|
||||
RateLimit float64 // Requests per second (0 = unlimited)
|
||||
RateBurst int // Rate limiter burst
|
||||
CircuitBreaker CircuitBreakerConfig
|
||||
UserAgent string
|
||||
CircuitBreaker CircuitBreakerConfig
|
||||
Timeout time.Duration
|
||||
MaxRetries int
|
||||
RetryDelay time.Duration
|
||||
RateLimit float64
|
||||
RateBurst int
|
||||
MaxConnsPerHost int
|
||||
}
|
||||
|
||||
// RetryConfig holds retry configuration
|
||||
type RetryConfig struct {
|
||||
FixedDelays []time.Duration
|
||||
MaxAttempts int
|
||||
InitialDelay time.Duration
|
||||
MaxDelay time.Duration
|
||||
Multiplier float64
|
||||
FixedDelays []time.Duration // If set, use these delays instead of exponential backoff
|
||||
}
|
||||
|
||||
// CircuitBreakerConfig holds circuit breaker configuration
|
||||
@@ -63,11 +63,11 @@ const (
|
||||
|
||||
// CircuitBreaker implements the circuit breaker pattern
|
||||
type CircuitBreaker struct {
|
||||
lastFailureTime time.Time
|
||||
config CircuitBreakerConfig
|
||||
state CircuitBreakerState
|
||||
failures int
|
||||
successes int
|
||||
lastFailureTime time.Time
|
||||
halfOpenCalls int
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
@@ -19,14 +19,14 @@ import (
|
||||
// TestClientGet tests the HTTP client Get method with various scenarios
|
||||
func TestClientGet(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverBehavior func(*testing.T) *httptest.Server
|
||||
config network.Config
|
||||
headers map[string]string
|
||||
wantErr bool
|
||||
errContains string
|
||||
validateBody func(*testing.T, io.ReadCloser)
|
||||
validateStatus func(*testing.T, int)
|
||||
name string
|
||||
errContains string
|
||||
config network.Config
|
||||
wantErr bool
|
||||
}{
|
||||
// GOOD: Successful GET request
|
||||
{
|
||||
|
||||
@@ -24,22 +24,22 @@ type Worker struct {
|
||||
cache *cache.Manager
|
||||
analytics *analytics.Engine
|
||||
client *network.Client
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
interval time.Duration
|
||||
maxConcurrent int
|
||||
enabled bool
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// Config holds pre-warming worker configuration
|
||||
type Config struct {
|
||||
Enabled bool
|
||||
Interval time.Duration
|
||||
MaxConcurrent int
|
||||
TopPackages int
|
||||
CacheManager *cache.Manager
|
||||
Analytics *analytics.Engine
|
||||
NetworkClient *network.Client
|
||||
Interval time.Duration
|
||||
MaxConcurrent int
|
||||
TopPackages int
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
// NewWorker creates a new pre-warming worker
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/cache"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/network"
|
||||
)
|
||||
|
||||
// BaseHandler provides common functionality for all proxy handlers
|
||||
type BaseHandler struct {
|
||||
Cache *cache.Manager
|
||||
Client *network.Client
|
||||
Upstream string
|
||||
Registry string
|
||||
}
|
||||
|
||||
// Config holds common proxy configuration
|
||||
type Config struct {
|
||||
Upstream string // Upstream registry URL (e.g., registry.npmjs.org)
|
||||
}
|
||||
|
||||
// GetRegistry returns the registry type
|
||||
func (h *BaseHandler) GetRegistry() string {
|
||||
return h.Registry
|
||||
}
|
||||
|
||||
// NewBaseHandler creates a new base handler with common fields
|
||||
func NewBaseHandler(cache *cache.Manager, client *network.Client, registry, upstream string) *BaseHandler {
|
||||
return &BaseHandler{
|
||||
Cache: cache,
|
||||
Client: client,
|
||||
Upstream: upstream,
|
||||
Registry: registry,
|
||||
}
|
||||
}
|
||||
@@ -1,385 +0,0 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/cache"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestNewBaseHandler tests base handler creation
|
||||
func TestNewBaseHandler(t *testing.T) {
|
||||
// Use nil for cache and client since we're only testing structure
|
||||
handler := NewBaseHandler(nil, nil, "npm", "https://registry.npmjs.org")
|
||||
|
||||
require.NotNil(t, handler)
|
||||
assert.Equal(t, "npm", handler.Registry)
|
||||
assert.Equal(t, "https://registry.npmjs.org", handler.Upstream)
|
||||
assert.Nil(t, handler.Cache)
|
||||
assert.Nil(t, handler.Client)
|
||||
}
|
||||
|
||||
// TestGetRegistry tests registry type retrieval
|
||||
func TestGetRegistry(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
registry string
|
||||
}{
|
||||
{"npm registry", "npm"},
|
||||
{"pypi registry", "pypi"},
|
||||
{"go registry", "go"},
|
||||
{"custom registry", "custom"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
handler := &BaseHandler{Registry: tt.registry}
|
||||
assert.Equal(t, tt.registry, handler.GetRegistry())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleUpstreamError tests upstream error handling
|
||||
func TestHandleUpstreamError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
url string
|
||||
context string
|
||||
wantStatus int
|
||||
wantContain string
|
||||
}{
|
||||
// GOOD: Standard error
|
||||
{
|
||||
name: "connection error",
|
||||
err: errors.New("connection refused"),
|
||||
url: "https://registry.npmjs.org/react",
|
||||
context: "package",
|
||||
wantStatus: http.StatusBadGateway,
|
||||
wantContain: "Failed to fetch package",
|
||||
},
|
||||
// WRONG: Timeout error
|
||||
{
|
||||
name: "timeout error",
|
||||
err: context.DeadlineExceeded,
|
||||
url: "https://registry.npmjs.org/lodash",
|
||||
context: "metadata",
|
||||
wantStatus: http.StatusBadGateway,
|
||||
wantContain: "Failed to fetch metadata",
|
||||
},
|
||||
// EDGE: Empty context
|
||||
{
|
||||
name: "empty context",
|
||||
err: errors.New("error"),
|
||||
url: "https://example.com",
|
||||
context: "",
|
||||
wantStatus: http.StatusBadGateway,
|
||||
wantContain: "Failed to fetch",
|
||||
},
|
||||
// EDGE: Long URL
|
||||
{
|
||||
name: "long URL",
|
||||
err: errors.New("error"),
|
||||
url: "https://registry.npmjs.org/@scope/very-long-package-name/versions/1.2.3",
|
||||
context: "package",
|
||||
wantStatus: http.StatusBadGateway,
|
||||
wantContain: "Failed to fetch package",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
HandleUpstreamError(w, tt.err, tt.url, tt.context)
|
||||
|
||||
assert.Equal(t, tt.wantStatus, w.Code)
|
||||
assert.Contains(t, w.Body.String(), tt.wantContain)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckUpstreamStatus tests upstream status validation
|
||||
func TestCheckUpstreamStatus(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body io.ReadCloser
|
||||
wantErr bool
|
||||
errContains string
|
||||
bodyClosed bool
|
||||
}{
|
||||
// GOOD: OK status
|
||||
{
|
||||
name: "200 OK",
|
||||
statusCode: http.StatusOK,
|
||||
body: io.NopCloser(strings.NewReader("success")),
|
||||
wantErr: false,
|
||||
},
|
||||
// WRONG: Not found
|
||||
{
|
||||
name: "404 Not Found",
|
||||
statusCode: http.StatusNotFound,
|
||||
body: io.NopCloser(strings.NewReader("not found")),
|
||||
wantErr: true,
|
||||
errContains: "upstream returned status 404",
|
||||
},
|
||||
// WRONG: Server error
|
||||
{
|
||||
name: "500 Internal Server Error",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
body: io.NopCloser(strings.NewReader("error")),
|
||||
wantErr: true,
|
||||
errContains: "upstream returned status 500",
|
||||
},
|
||||
// BAD: Unauthorized
|
||||
{
|
||||
name: "401 Unauthorized",
|
||||
statusCode: http.StatusUnauthorized,
|
||||
body: io.NopCloser(strings.NewReader("unauthorized")),
|
||||
wantErr: true,
|
||||
errContains: "upstream returned status 401",
|
||||
},
|
||||
// EDGE: Nil body
|
||||
{
|
||||
name: "nil body with error",
|
||||
statusCode: http.StatusNotFound,
|
||||
body: nil,
|
||||
wantErr: true,
|
||||
errContains: "upstream returned status 404",
|
||||
},
|
||||
// EDGE: Redirect status
|
||||
{
|
||||
name: "302 Found",
|
||||
statusCode: http.StatusFound,
|
||||
body: io.NopCloser(strings.NewReader("redirect")),
|
||||
wantErr: true,
|
||||
errContains: "upstream returned status 302",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := CheckUpstreamStatus(tt.statusCode, tt.body)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleInvalidRequest tests invalid request handling
|
||||
func TestHandleInvalidRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
registry string
|
||||
wantStatus int
|
||||
wantContain string
|
||||
}{
|
||||
{
|
||||
name: "npm invalid request",
|
||||
registry: "npm",
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantContain: "Invalid npm request",
|
||||
},
|
||||
{
|
||||
name: "pypi invalid request",
|
||||
registry: "pypi",
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantContain: "Invalid pypi request",
|
||||
},
|
||||
{
|
||||
name: "go invalid request",
|
||||
registry: "go",
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantContain: "Invalid go request",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
HandleInvalidRequest(w, tt.registry)
|
||||
|
||||
assert.Equal(t, tt.wantStatus, w.Code)
|
||||
assert.Contains(t, w.Body.String(), tt.wantContain)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleInternalError tests internal error handling
|
||||
func TestHandleInternalError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
context string
|
||||
wantStatus int
|
||||
wantContain string
|
||||
}{
|
||||
{
|
||||
name: "database error",
|
||||
err: errors.New("database connection failed"),
|
||||
context: "database",
|
||||
wantStatus: http.StatusInternalServerError,
|
||||
wantContain: "Internal error: database",
|
||||
},
|
||||
{
|
||||
name: "cache error",
|
||||
err: errors.New("cache write failed"),
|
||||
context: "cache",
|
||||
wantStatus: http.StatusInternalServerError,
|
||||
wantContain: "Internal error: cache",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
HandleInternalError(w, tt.err, tt.context)
|
||||
|
||||
assert.Equal(t, tt.wantStatus, w.Code)
|
||||
assert.Contains(t, w.Body.String(), tt.wantContain)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Note: FetchFromUpstream tests would require mocking cache.Manager and network.Client
|
||||
// which requires concrete implementations. Integration tests cover this functionality.
|
||||
|
||||
// TestWriteResponse tests HTTP response writing
|
||||
func TestWriteResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data string
|
||||
contentType string
|
||||
wantStatus int
|
||||
wantBody string
|
||||
wantErr bool
|
||||
}{
|
||||
// GOOD: Write tarball
|
||||
{
|
||||
name: "write tarball",
|
||||
data: "package data here",
|
||||
contentType: "application/octet-stream",
|
||||
wantStatus: http.StatusOK,
|
||||
wantBody: "package data here",
|
||||
wantErr: false,
|
||||
},
|
||||
// GOOD: Write JSON
|
||||
{
|
||||
name: "write JSON metadata",
|
||||
data: `{"name":"react","version":"18.2.0"}`,
|
||||
contentType: "application/json",
|
||||
wantStatus: http.StatusOK,
|
||||
wantBody: `{"name":"react","version":"18.2.0"}`,
|
||||
wantErr: false,
|
||||
},
|
||||
// EDGE: Empty data
|
||||
{
|
||||
name: "empty data",
|
||||
data: "",
|
||||
contentType: "text/plain",
|
||||
wantStatus: http.StatusOK,
|
||||
wantBody: "",
|
||||
wantErr: false,
|
||||
},
|
||||
// EDGE: Large data
|
||||
{
|
||||
name: "large data",
|
||||
data: strings.Repeat("x", 100000),
|
||||
contentType: "application/octet-stream",
|
||||
wantStatus: http.StatusOK,
|
||||
wantBody: strings.Repeat("x", 100000),
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
entry := &cache.CacheEntry{
|
||||
Data: io.NopCloser(bytes.NewReader([]byte(tt.data))),
|
||||
}
|
||||
|
||||
err := WriteResponse(w, entry, tt.contentType)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.contentType, w.Header().Get("Content-Type"))
|
||||
assert.Equal(t, tt.wantBody, w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseHandlerFields tests that BaseHandler fields are properly set
|
||||
func TestBaseHandlerFields(t *testing.T) {
|
||||
handler := NewBaseHandler(nil, nil, "npm", "https://registry.npmjs.org")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
field string
|
||||
expected interface{}
|
||||
}{
|
||||
{"registry field", "registry", "npm"},
|
||||
{"upstream field", "upstream", "https://registry.npmjs.org"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
switch tt.field {
|
||||
case "registry":
|
||||
assert.Equal(t, tt.expected, handler.Registry)
|
||||
case "upstream":
|
||||
assert.Equal(t, tt.expected, handler.Upstream)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProxyHandlerInterface tests that BaseHandler can be used as ProxyHandler
|
||||
func TestProxyHandlerInterface(t *testing.T) {
|
||||
handler := NewBaseHandler(nil, nil, "npm", "https://registry.npmjs.org")
|
||||
|
||||
// Verify GetRegistry works
|
||||
registry := handler.GetRegistry()
|
||||
assert.Equal(t, "npm", registry)
|
||||
}
|
||||
|
||||
// TestConcurrentWriteResponse tests that WriteResponse is safe for concurrent use
|
||||
func TestConcurrentWriteResponse(t *testing.T) {
|
||||
const numGoroutines = 10
|
||||
|
||||
errs := make(chan error, numGoroutines)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(n int) {
|
||||
w := httptest.NewRecorder()
|
||||
data := strings.Repeat("x", 1000)
|
||||
entry := &cache.CacheEntry{
|
||||
Data: io.NopCloser(bytes.NewReader([]byte(data))),
|
||||
}
|
||||
|
||||
err := WriteResponse(w, entry, "text/plain")
|
||||
errs <- err
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Collect results
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
err := <-errs
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
@@ -1,48 +0,0 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// HandleUpstreamError logs an error and sends an HTTP 502 Bad Gateway response
|
||||
// This is the common pattern used across all proxy handlers when upstream fetch fails
|
||||
func HandleUpstreamError(w http.ResponseWriter, err error, url, context string) {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("url", url).
|
||||
Str("context", context).
|
||||
Msg("Failed to fetch from upstream")
|
||||
|
||||
http.Error(w, fmt.Sprintf("Failed to fetch %s", context), http.StatusBadGateway)
|
||||
}
|
||||
|
||||
// CheckUpstreamStatus validates HTTP status code from upstream
|
||||
// Returns error if status is not OK, closing body if needed
|
||||
func CheckUpstreamStatus(statusCode int, body io.ReadCloser) error {
|
||||
if statusCode != http.StatusOK {
|
||||
if body != nil {
|
||||
body.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
}
|
||||
return fmt.Errorf("upstream returned status %d", statusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleInvalidRequest sends a 400 Bad Request response for invalid proxy requests
|
||||
func HandleInvalidRequest(w http.ResponseWriter, registry string) {
|
||||
http.Error(w, fmt.Sprintf("Invalid %s request", registry), http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// HandleInternalError logs an internal error and sends 500 response
|
||||
func HandleInternalError(w http.ResponseWriter, err error, context string) {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("context", context).
|
||||
Msg("Internal error processing request")
|
||||
|
||||
http.Error(w, fmt.Sprintf("Internal error: %s", context), http.StatusInternalServerError)
|
||||
}
|
||||
@@ -1,58 +0,0 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/cache"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/network"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// FetchFromUpstream is a common helper to fetch content from upstream with caching
|
||||
// This encapsulates the common pattern of: cache.Get -> network.Get -> error handling
|
||||
func FetchFromUpstream(
|
||||
ctx context.Context,
|
||||
cacheManager *cache.Manager,
|
||||
client *network.Client,
|
||||
registry, name, version, upstreamURL string,
|
||||
) (*cache.CacheEntry, error) {
|
||||
entry, err := cacheManager.Get(ctx, registry, name, version, func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
body, statusCode, err := client.Get(ctx, upstreamURL, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if err := CheckUpstreamStatus(statusCode, body); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return body, upstreamURL, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("url", upstreamURL).
|
||||
Str("registry", registry).
|
||||
Str("name", name).
|
||||
Str("version", version).
|
||||
Msg("Failed to fetch package from upstream")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
// WriteResponse writes the cache entry data to the HTTP response writer
|
||||
// Sets appropriate content type and handles errors
|
||||
func WriteResponse(w http.ResponseWriter, entry *cache.CacheEntry, contentType string) error {
|
||||
defer entry.Data.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
if _, err := io.Copy(w, entry.Data); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to write response")
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ProxyHandler defines the common interface for all registry proxies
|
||||
type ProxyHandler interface {
|
||||
http.Handler // ServeHTTP(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
// GetRegistry returns the registry type (npm, pypi, go)
|
||||
GetRegistry() string
|
||||
|
||||
// Health checks if the proxy can reach its upstream
|
||||
Health(ctx context.Context) error
|
||||
}
|
||||
|
||||
// Stats represents proxy statistics
|
||||
type Stats struct {
|
||||
Registry string
|
||||
TotalRequests int64
|
||||
CacheHits int64
|
||||
CacheMisses int64
|
||||
UpstreamErrors int64
|
||||
AvgResponseTime time.Duration
|
||||
LastUpdated time.Time
|
||||
}
|
||||
@@ -20,21 +20,21 @@ import (
|
||||
type Handler struct {
|
||||
cache *cache.Manager
|
||||
client *network.Client
|
||||
upstream string
|
||||
sumDBURL string
|
||||
credExtractor *auth.CredentialExtractor
|
||||
credHasher *auth.CredentialHasher
|
||||
credValidator *auth.GoValidator
|
||||
validationCache *auth.ValidationCache
|
||||
gitFetcher *vcs.GitFetcher
|
||||
moduleBuilder *vcs.ModuleBuilder
|
||||
upstream string
|
||||
sumDBURL string
|
||||
}
|
||||
|
||||
// Config holds Go proxy configuration
|
||||
type Config struct {
|
||||
Upstream string // Upstream Go proxy (e.g., proxy.golang.org)
|
||||
SumDBURL string // Checksum database URL
|
||||
CredStore *vcs.CredentialStore // Optional credential store for git access
|
||||
CredStore *vcs.CredentialStore
|
||||
Upstream string
|
||||
SumDBURL string
|
||||
}
|
||||
|
||||
// New creates a new Go proxy handler
|
||||
|
||||
@@ -21,11 +21,11 @@ import (
|
||||
type Handler struct {
|
||||
cache *cache.Manager
|
||||
client *network.Client
|
||||
upstream string
|
||||
credExtractor *auth.CredentialExtractor
|
||||
credHasher *auth.CredentialHasher
|
||||
credValidator *auth.NPMValidator
|
||||
validationCache *auth.ValidationCache
|
||||
upstream string
|
||||
}
|
||||
|
||||
// Config holds NPM proxy configuration
|
||||
|
||||
@@ -21,11 +21,11 @@ import (
|
||||
type Handler struct {
|
||||
cache *cache.Manager
|
||||
client *network.Client
|
||||
upstream string
|
||||
credExtractor *auth.CredentialExtractor
|
||||
credHasher *auth.CredentialHasher
|
||||
credValidator *auth.PyPIValidator
|
||||
validationCache *auth.ValidationCache
|
||||
upstream string
|
||||
}
|
||||
|
||||
// Config holds PyPI proxy configuration
|
||||
|
||||
@@ -20,8 +20,8 @@ const ScannerName = "github-advisory-database"
|
||||
|
||||
// Scanner implements the GitHub Advisory Database vulnerability scanner
|
||||
type Scanner struct {
|
||||
config config.GHSAConfig
|
||||
httpClient *http.Client
|
||||
config config.GHSAConfig
|
||||
}
|
||||
|
||||
// New creates a new GitHub Advisory Database scanner
|
||||
@@ -257,10 +257,10 @@ type GHSAAdvisory struct {
|
||||
Description string `json:"description"`
|
||||
Severity string `json:"severity"`
|
||||
HTMLURL string `json:"html_url"`
|
||||
References []GHSAReference `json:"references"`
|
||||
Vulnerabilities []GHSAVulnerability `json:"vulnerabilities"`
|
||||
PublishedAt string `json:"published_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
References []GHSAReference `json:"references"`
|
||||
Vulnerabilities []GHSAVulnerability `json:"vulnerabilities"`
|
||||
}
|
||||
|
||||
type GHSAReference struct {
|
||||
@@ -268,9 +268,9 @@ type GHSAReference struct {
|
||||
}
|
||||
|
||||
type GHSAVulnerability struct {
|
||||
FirstPatchedVersion *GHSAPatchVersion `json:"first_patched_version"`
|
||||
Package GHSAPackage `json:"package"`
|
||||
VulnerableVersions string `json:"vulnerable_version_range"`
|
||||
FirstPatchedVersion *GHSAPatchVersion `json:"first_patched_version"`
|
||||
}
|
||||
|
||||
type GHSAPackage struct {
|
||||
|
||||
@@ -153,9 +153,9 @@ func (s *Scanner) convertGrypeResult(grypeResult *GrypeResult, registry, package
|
||||
|
||||
// GrypeResult represents Grype JSON output structure
|
||||
type GrypeResult struct {
|
||||
Matches []GrypeMatch `json:"matches"`
|
||||
Descriptor GrypeDescriptor `json:"descriptor"`
|
||||
Source GrypeSource `json:"source"`
|
||||
Descriptor GrypeDescriptor `json:"descriptor"`
|
||||
Matches []GrypeMatch `json:"matches"`
|
||||
}
|
||||
|
||||
type GrypeDescriptor struct {
|
||||
@@ -164,13 +164,13 @@ type GrypeDescriptor struct {
|
||||
}
|
||||
|
||||
type GrypeSource struct {
|
||||
Type string `json:"type"`
|
||||
Target map[string]interface{} `json:"target"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type GrypeMatch struct {
|
||||
Vulnerability GrypeVulnerability `json:"vulnerability"`
|
||||
Artifact GrypeArtifact `json:"artifact"`
|
||||
Vulnerability GrypeVulnerability `json:"vulnerability"`
|
||||
}
|
||||
|
||||
type GrypeVulnerability struct {
|
||||
|
||||
@@ -199,9 +199,9 @@ func (s *Scanner) convertResult(auditResult *NpmAuditResult, registry, packageNa
|
||||
|
||||
// NpmAuditResult represents npm audit JSON output
|
||||
type NpmAuditResult struct {
|
||||
AuditReportVersion int `json:"auditReportVersion"`
|
||||
Vulnerabilities map[string]NpmVulnerability `json:"vulnerabilities"`
|
||||
Metadata NpmAuditMetadata `json:"metadata"`
|
||||
AuditReportVersion int `json:"auditReportVersion"`
|
||||
}
|
||||
|
||||
type NpmVulnerability struct {
|
||||
|
||||
@@ -25,8 +25,8 @@ const (
|
||||
|
||||
// Scanner implements the Scanner interface using OSV.dev API
|
||||
type Scanner struct {
|
||||
config config.OSVConfig
|
||||
httpClient *http.Client
|
||||
config config.OSVConfig
|
||||
}
|
||||
|
||||
// OSVRequest represents the request structure for OSV API
|
||||
@@ -48,13 +48,13 @@ type OSVResponse struct {
|
||||
|
||||
// OSVVulnerability represents a vulnerability in OSV format
|
||||
type OSVVulnerability struct {
|
||||
DatabaseSpecific map[string]interface{} `json:"database_specific,omitempty"`
|
||||
ID string `json:"id"`
|
||||
Summary string `json:"summary"`
|
||||
Details string `json:"details"`
|
||||
Severity []OSVSeverity `json:"severity,omitempty"`
|
||||
References []OSVReference `json:"references,omitempty"`
|
||||
Affected []OSVAffected `json:"affected"`
|
||||
DatabaseSpecific map[string]interface{} `json:"database_specific,omitempty"`
|
||||
}
|
||||
|
||||
// OSVSeverity represents severity information
|
||||
@@ -71,11 +71,11 @@ type OSVReference struct {
|
||||
|
||||
// OSVAffected represents affected package versions
|
||||
type OSVAffected struct {
|
||||
DatabaseSpecific map[string]interface{} `json:"database_specific,omitempty"`
|
||||
EcosystemSpecific map[string]interface{} `json:"ecosystem_specific,omitempty"`
|
||||
Package PackageInfo `json:"package"`
|
||||
Ranges []OSVRange `json:"ranges,omitempty"`
|
||||
Versions []string `json:"versions,omitempty"`
|
||||
DatabaseSpecific map[string]interface{} `json:"database_specific,omitempty"`
|
||||
EcosystemSpecific map[string]interface{} `json:"ecosystem_specific,omitempty"`
|
||||
}
|
||||
|
||||
// OSVRange represents version ranges
|
||||
|
||||
@@ -11,11 +11,11 @@ import (
|
||||
|
||||
// RescanWorker handles periodic re-scanning of cached packages
|
||||
type RescanWorker struct {
|
||||
manager *Manager
|
||||
metadataStore metadata.MetadataStore
|
||||
storage storage.StorageBackend
|
||||
interval time.Duration
|
||||
manager *Manager
|
||||
stopCh chan struct{}
|
||||
interval time.Duration
|
||||
}
|
||||
|
||||
// NewRescanWorker creates a new rescan worker
|
||||
|
||||
@@ -36,10 +36,10 @@ type DatabaseUpdater interface {
|
||||
|
||||
// Manager manages multiple security scanners
|
||||
type Manager struct {
|
||||
metadataStore metadata.MetadataStore
|
||||
config config.SecurityConfig
|
||||
scanners []Scanner
|
||||
enabled bool
|
||||
config config.SecurityConfig
|
||||
metadataStore metadata.MetadataStore
|
||||
}
|
||||
|
||||
// New creates a new scanner manager with configured scanners
|
||||
|
||||
@@ -25,18 +25,18 @@ type Scanner struct {
|
||||
|
||||
// TrivyResult represents Trivy JSON output structure
|
||||
type TrivyResult struct {
|
||||
SchemaVersion int `json:"SchemaVersion"`
|
||||
Metadata TrivyMetadata `json:"Metadata"`
|
||||
ArtifactName string `json:"ArtifactName"`
|
||||
ArtifactType string `json:"ArtifactType"`
|
||||
Metadata TrivyMetadata `json:"Metadata"`
|
||||
Results []TrivyVulnResult `json:"Results"`
|
||||
SchemaVersion int `json:"SchemaVersion"`
|
||||
}
|
||||
|
||||
type TrivyMetadata struct {
|
||||
OS *TrivyOS `json:"OS,omitempty"`
|
||||
ImageConfig *TrivyImageConfig `json:"ImageConfig,omitempty"`
|
||||
RepoTags []string `json:"RepoTags,omitempty"`
|
||||
RepoDigests []string `json:"RepoDigests,omitempty"`
|
||||
ImageConfig *TrivyImageConfig `json:"ImageConfig,omitempty"`
|
||||
}
|
||||
|
||||
type TrivyOS struct {
|
||||
@@ -64,8 +64,8 @@ type TrivyVulnerability struct {
|
||||
Severity string `json:"Severity"`
|
||||
Title string `json:"Title"`
|
||||
Description string `json:"Description"`
|
||||
References []string `json:"References"`
|
||||
PrimaryURL string `json:"PrimaryURL"`
|
||||
References []string `json:"References"`
|
||||
}
|
||||
|
||||
// New creates a new Trivy scanner
|
||||
|
||||
@@ -1,130 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/config"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/health"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/logger"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metrics"
|
||||
)
|
||||
|
||||
// Server wraps http.Server with configuration
|
||||
type Server struct {
|
||||
*http.Server
|
||||
config *config.Config
|
||||
healthChecker *health.Checker
|
||||
}
|
||||
|
||||
// New creates a new HTTP server
|
||||
func New(cfg *config.Config, healthChecker *health.Checker) (*Server, error) {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// Register routes
|
||||
registerRoutes(mux, cfg, healthChecker)
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
|
||||
Handler: logger.Middleware(mux),
|
||||
ReadTimeout: cfg.Server.ReadTimeout,
|
||||
WriteTimeout: cfg.Server.WriteTimeout,
|
||||
IdleTimeout: cfg.Server.IdleTimeout,
|
||||
}
|
||||
|
||||
return &Server{
|
||||
Server: srv,
|
||||
config: cfg,
|
||||
healthChecker: healthChecker,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// registerRoutes registers all HTTP routes
|
||||
func registerRoutes(mux *http.ServeMux, cfg *config.Config, healthChecker *health.Checker) {
|
||||
// Health endpoints
|
||||
mux.HandleFunc("/health", healthChecker.HealthHandler())
|
||||
mux.HandleFunc("/health/ready", healthChecker.ReadyHandler())
|
||||
|
||||
// Metrics endpoint
|
||||
mux.Handle("/metrics", metrics.Handler())
|
||||
|
||||
// API endpoints
|
||||
mux.HandleFunc("/api/v1/info", handleInfo(cfg))
|
||||
|
||||
// Package manager proxy endpoints (placeholders for now)
|
||||
if cfg.Handlers.Go.Enabled {
|
||||
mux.HandleFunc("/go/", handleGoProxy())
|
||||
}
|
||||
if cfg.Handlers.NPM.Enabled {
|
||||
mux.HandleFunc("/npm/", handleNPMProxy())
|
||||
}
|
||||
if cfg.Handlers.PyPI.Enabled {
|
||||
mux.HandleFunc("/pypi/", handlePyPIProxy())
|
||||
}
|
||||
|
||||
// Root endpoint
|
||||
mux.HandleFunc("/", handleRoot())
|
||||
}
|
||||
|
||||
// handleInfo returns server information
|
||||
func handleInfo(cfg *config.Config) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
info := map[string]interface{}{
|
||||
"name": "GoHoarder",
|
||||
"version": "dev",
|
||||
"handlers": map[string]bool{
|
||||
"go": cfg.Handlers.Go.Enabled,
|
||||
"npm": cfg.Handlers.NPM.Enabled,
|
||||
"pypi": cfg.Handlers.PyPI.Enabled,
|
||||
},
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"success":true,"data":%v}`, toJSON(info))
|
||||
}
|
||||
}
|
||||
|
||||
// handleGoProxy handles Go module proxy requests
|
||||
func handleGoProxy() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO: Implement Go proxy handler
|
||||
http.Error(w, `{"success":false,"error":{"code":"NOT_IMPLEMENTED","message":"Go proxy not yet implemented"}}`, http.StatusNotImplemented)
|
||||
}
|
||||
}
|
||||
|
||||
// handleNPMProxy handles NPM registry requests
|
||||
func handleNPMProxy() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO: Implement NPM proxy handler
|
||||
http.Error(w, `{"success":false,"error":{"code":"NOT_IMPLEMENTED","message":"NPM proxy not yet implemented"}}`, http.StatusNotImplemented)
|
||||
}
|
||||
}
|
||||
|
||||
// handlePyPIProxy handles PyPI requests
|
||||
func handlePyPIProxy() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO: Implement PyPI proxy handler
|
||||
http.Error(w, `{"success":false,"error":{"code":"NOT_IMPLEMENTED","message":"PyPI proxy not yet implemented"}}`, http.StatusNotImplemented)
|
||||
}
|
||||
}
|
||||
|
||||
// handleRoot handles root path
|
||||
func handleRoot() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, `{"success":true,"data":{"message":"GoHoarder - Universal Package Cache Proxy","docs":"https://github.com/lukaszraczylo/gohoarder"}}`)
|
||||
}
|
||||
}
|
||||
|
||||
// toJSON is a simple JSON encoder (replace with proper implementation)
|
||||
func toJSON(v interface{}) string {
|
||||
// Simplified for now - proper implementation would use goccy/go-json
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
@@ -17,8 +17,8 @@ import (
|
||||
|
||||
type FilesystemStorageTestSuite struct {
|
||||
suite.Suite
|
||||
tempDir string
|
||||
fs *FilesystemStorage
|
||||
tempDir string
|
||||
}
|
||||
|
||||
func (s *FilesystemStorageTestSuite) SetupTest() {
|
||||
@@ -46,12 +46,12 @@ func TestFilesystemStorageTestSuite(t *testing.T) {
|
||||
// Test Put operation
|
||||
func (s *FilesystemStorageTestSuite) TestPut() {
|
||||
tests := []struct {
|
||||
opts *storage.PutOptions
|
||||
errorCheck func(error) bool
|
||||
name string
|
||||
key string
|
||||
data string
|
||||
opts *storage.PutOptions
|
||||
expectError bool
|
||||
errorCheck func(error) bool
|
||||
}{
|
||||
{
|
||||
name: "successful put",
|
||||
@@ -122,8 +122,8 @@ func (s *FilesystemStorageTestSuite) TestGet() {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
expectError bool
|
||||
expectData string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "get existing file",
|
||||
@@ -258,11 +258,11 @@ func (s *FilesystemStorageTestSuite) TestList() {
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
opts *storage.ListOptions
|
||||
name string
|
||||
prefix string
|
||||
opts *storage.ListOptions
|
||||
expectedCount int
|
||||
expectedKeys []string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "list all npm packages",
|
||||
@@ -422,8 +422,8 @@ func (s *FilesystemStorageTestSuite) TestContextCancellation() {
|
||||
cancel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func() error
|
||||
name string
|
||||
}{
|
||||
{
|
||||
name: "Get with cancelled context",
|
||||
@@ -679,8 +679,8 @@ func (s *FilesystemStorageTestSuite) TestChecksumValidation() {
|
||||
correctMD5 := "7dd7323e8ce3e087972f93d3711ef62b"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts *storage.PutOptions
|
||||
name string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
|
||||
@@ -52,21 +52,21 @@ type ListOptions struct {
|
||||
|
||||
// StorageObject represents a stored object
|
||||
type StorageObject struct {
|
||||
Key string
|
||||
Size int64
|
||||
Modified time.Time
|
||||
Key string
|
||||
ETag string
|
||||
Size int64
|
||||
}
|
||||
|
||||
// StorageInfo contains detailed object information
|
||||
type StorageInfo struct {
|
||||
Key string
|
||||
Size int64
|
||||
Modified time.Time
|
||||
ETag string
|
||||
ContentType string
|
||||
Metadata map[string]string
|
||||
Checksums *Checksums
|
||||
Key string
|
||||
ETag string
|
||||
ContentType string
|
||||
Size int64
|
||||
}
|
||||
|
||||
// Checksums contains file checksums
|
||||
|
||||
+189
-244
@@ -3,14 +3,10 @@ package s3
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/md5" // #nosec G501 -- MD5 used for S3 Content-MD5 header, not cryptographic security
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
stderrors "errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
@@ -18,261 +14,210 @@ import (
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3/types"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/errors"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metrics"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/storage"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// S3Storage implements storage.StorageBackend for AWS S3
|
||||
type S3Storage struct {
|
||||
client *s3.Client
|
||||
bucket string
|
||||
prefix string
|
||||
quota int64
|
||||
mu sync.RWMutex
|
||||
used int64
|
||||
}
|
||||
|
||||
// Config holds S3 configuration
|
||||
// Config holds S3 storage configuration
|
||||
type Config struct {
|
||||
Bucket string
|
||||
Region string
|
||||
Endpoint string // For S3-compatible services (MinIO, etc.)
|
||||
Bucket string
|
||||
Prefix string
|
||||
AccessKeyID string
|
||||
SecretAccessKey string
|
||||
Prefix string // Optional prefix for all keys
|
||||
Quota int64 // Quota in bytes (0 = unlimited)
|
||||
ForcePathStyle bool // For S3-compatible services
|
||||
Endpoint string // Optional: for S3-compatible services like MinIO
|
||||
ForcePathStyle bool // Optional: for S3-compatible services
|
||||
MaxSizeBytes int64
|
||||
}
|
||||
|
||||
// S3Storage implements storage.StorageBackend using AWS S3
|
||||
type S3Storage struct {
|
||||
client *s3.Client
|
||||
bucket string
|
||||
prefix string
|
||||
maxSizeBytes int64
|
||||
}
|
||||
|
||||
// New creates a new S3 storage backend
|
||||
func New(ctx context.Context, cfg Config) (*S3Storage, error) {
|
||||
func New(cfg Config) (*S3Storage, error) {
|
||||
if cfg.Bucket == "" {
|
||||
return nil, errors.New(errors.ErrCodeInvalidConfig, "S3 bucket is required")
|
||||
return nil, fmt.Errorf("S3 bucket is required")
|
||||
}
|
||||
|
||||
if cfg.Region == "" {
|
||||
return nil, errors.New(errors.ErrCodeInvalidConfig, "S3 region is required")
|
||||
cfg.Region = "us-east-1" // Default region
|
||||
}
|
||||
|
||||
// Build AWS config
|
||||
var awsCfg aws.Config
|
||||
var awsConfig aws.Config
|
||||
var err error
|
||||
|
||||
// Build config options
|
||||
configOpts := []func(*config.LoadOptions) error{
|
||||
config.WithRegion(cfg.Region),
|
||||
}
|
||||
|
||||
// Add credentials if provided
|
||||
if cfg.AccessKeyID != "" && cfg.SecretAccessKey != "" {
|
||||
// Use static credentials
|
||||
awsCfg, err = config.LoadDefaultConfig(ctx,
|
||||
config.WithRegion(cfg.Region),
|
||||
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
|
||||
configOpts = append(configOpts, config.WithCredentialsProvider(
|
||||
credentials.NewStaticCredentialsProvider(
|
||||
cfg.AccessKeyID,
|
||||
cfg.SecretAccessKey,
|
||||
"",
|
||||
)),
|
||||
)
|
||||
} else {
|
||||
// Use default credential chain
|
||||
awsCfg, err = config.LoadDefaultConfig(ctx,
|
||||
config.WithRegion(cfg.Region),
|
||||
)
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
awsConfig, err = config.LoadDefaultConfig(context.Background(), configOpts...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to load AWS config")
|
||||
return nil, fmt.Errorf("failed to load AWS config: %w", err)
|
||||
}
|
||||
|
||||
// Create S3 client
|
||||
var s3Options []func(*s3.Options)
|
||||
|
||||
if cfg.Endpoint != "" {
|
||||
s3Options = append(s3Options, func(o *s3.Options) {
|
||||
// Create S3 client with service-specific options
|
||||
client := s3.NewFromConfig(awsConfig, func(o *s3.Options) {
|
||||
// Use custom endpoint if provided (for MinIO, S3-compatible services, etc.)
|
||||
if cfg.Endpoint != "" {
|
||||
o.BaseEndpoint = aws.String(cfg.Endpoint)
|
||||
o.UsePathStyle = cfg.ForcePathStyle
|
||||
})
|
||||
}
|
||||
if cfg.ForcePathStyle {
|
||||
o.UsePathStyle = true
|
||||
}
|
||||
})
|
||||
|
||||
storage := &S3Storage{
|
||||
client: client,
|
||||
bucket: cfg.Bucket,
|
||||
prefix: strings.TrimSuffix(cfg.Prefix, "/"),
|
||||
maxSizeBytes: cfg.MaxSizeBytes,
|
||||
}
|
||||
|
||||
client := s3.NewFromConfig(awsCfg, s3Options...)
|
||||
log.Info().
|
||||
Str("bucket", cfg.Bucket).
|
||||
Str("region", cfg.Region).
|
||||
Str("prefix", cfg.Prefix).
|
||||
Msg("S3 storage initialized")
|
||||
|
||||
s3Storage := &S3Storage{
|
||||
client: client,
|
||||
bucket: cfg.Bucket,
|
||||
prefix: strings.TrimSuffix(cfg.Prefix, "/"),
|
||||
quota: cfg.Quota,
|
||||
}
|
||||
|
||||
// Calculate initial usage
|
||||
if err := s3Storage.calculateUsage(ctx); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to calculate initial S3 storage usage")
|
||||
}
|
||||
|
||||
return s3Storage, nil
|
||||
return storage, nil
|
||||
}
|
||||
|
||||
// Get retrieves a file from S3
|
||||
// Get retrieves data from S3
|
||||
func (s *S3Storage) Get(ctx context.Context, key string) (io.ReadCloser, error) {
|
||||
s3Key := s.buildKey(key)
|
||||
fullKey := s.buildKey(key)
|
||||
|
||||
input := &s3.GetObjectInput{
|
||||
log.Debug().Str("key", fullKey).Msg("Getting object from S3")
|
||||
|
||||
result, err := s.client.GetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(s3Key),
|
||||
}
|
||||
Key: aws.String(fullKey),
|
||||
})
|
||||
|
||||
result, err := s.client.GetObject(ctx, input)
|
||||
if err != nil {
|
||||
if isNotFoundError(err) {
|
||||
metrics.RecordStorageOperation("s3", "get", "not_found")
|
||||
return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key))
|
||||
return nil, errors.NotFound(fmt.Sprintf("S3 object not found: %s", key))
|
||||
}
|
||||
metrics.RecordStorageOperation("s3", "get", "error")
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get object from S3")
|
||||
}
|
||||
|
||||
metrics.RecordStorageOperation("s3", "get", "success")
|
||||
return result.Body, nil
|
||||
}
|
||||
|
||||
// Put stores a file in S3
|
||||
// Put stores data in S3
|
||||
func (s *S3Storage) Put(ctx context.Context, key string, data io.Reader, opts *storage.PutOptions) error {
|
||||
s3Key := s.buildKey(key)
|
||||
fullKey := s.buildKey(key)
|
||||
|
||||
// Read data into buffer to calculate checksums and size
|
||||
var buf bytes.Buffer
|
||||
md5Hash := md5.New() // #nosec G401 -- MD5 used for S3 integrity check, not cryptographic security
|
||||
sha256Hash := sha256.New()
|
||||
multiWriter := io.MultiWriter(&buf, md5Hash, sha256Hash)
|
||||
|
||||
written, err := io.Copy(multiWriter, data)
|
||||
// Read data into buffer to get size
|
||||
buf := new(bytes.Buffer)
|
||||
size, err := io.Copy(buf, data)
|
||||
if err != nil {
|
||||
metrics.RecordStorageOperation("s3", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to read data")
|
||||
return fmt.Errorf("failed to read data: %w", err)
|
||||
}
|
||||
|
||||
// Check quota before upload
|
||||
if s.quota > 0 {
|
||||
s.mu.RLock()
|
||||
used := s.used
|
||||
s.mu.RUnlock()
|
||||
log.Debug().
|
||||
Str("key", fullKey).
|
||||
Int64("size", size).
|
||||
Msg("Putting object to S3")
|
||||
|
||||
if used+written > s.quota {
|
||||
metrics.RecordStorageOperation("s3", "put", "quota_exceeded")
|
||||
return errors.QuotaExceeded(s.quota)
|
||||
// Check quota if set
|
||||
if s.maxSizeBytes > 0 {
|
||||
currentUsage, err := s.calculateUsage(ctx)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to calculate current usage, skipping quota check")
|
||||
} else if currentUsage+size > s.maxSizeBytes {
|
||||
return errors.QuotaExceeded(s.maxSizeBytes)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify checksums if provided
|
||||
if opts != nil {
|
||||
md5Sum := hex.EncodeToString(md5Hash.Sum(nil))
|
||||
sha256Sum := hex.EncodeToString(sha256Hash.Sum(nil))
|
||||
|
||||
if opts.ChecksumMD5 != "" && opts.ChecksumMD5 != md5Sum {
|
||||
metrics.RecordStorageOperation("s3", "put", "checksum_error")
|
||||
return errors.New(errors.ErrCodeChecksumMismatch, "MD5 checksum mismatch")
|
||||
}
|
||||
|
||||
if opts.ChecksumSHA256 != "" && opts.ChecksumSHA256 != sha256Sum {
|
||||
metrics.RecordStorageOperation("s3", "put", "checksum_error")
|
||||
return errors.New(errors.ErrCodeChecksumMismatch, "SHA256 checksum mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare metadata
|
||||
metadata := make(map[string]string)
|
||||
// Convert metadata to S3 metadata format
|
||||
s3Metadata := make(map[string]string)
|
||||
if opts != nil && opts.Metadata != nil {
|
||||
metadata = opts.Metadata
|
||||
}
|
||||
|
||||
// Build put input
|
||||
input := &s3.PutObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(s3Key),
|
||||
Body: bytes.NewReader(buf.Bytes()),
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
if opts != nil && opts.ContentType != "" {
|
||||
input.ContentType = aws.String(opts.ContentType)
|
||||
for k, v := range opts.Metadata {
|
||||
s3Metadata[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// Upload to S3
|
||||
_, err = s.client.PutObject(ctx, input)
|
||||
_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(fullKey),
|
||||
Body: bytes.NewReader(buf.Bytes()),
|
||||
Metadata: s3Metadata,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
metrics.RecordStorageOperation("s3", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to upload to S3")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to put object to S3")
|
||||
}
|
||||
|
||||
// Update usage
|
||||
s.mu.Lock()
|
||||
s.used += written
|
||||
currentUsed := s.used
|
||||
s.mu.Unlock()
|
||||
|
||||
metrics.RecordStorageOperation("s3", "put", "success")
|
||||
metrics.UpdateCacheSize("s3", currentUsed)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a file from S3
|
||||
// Delete removes data from S3
|
||||
func (s *S3Storage) Delete(ctx context.Context, key string) error {
|
||||
s3Key := s.buildKey(key)
|
||||
fullKey := s.buildKey(key)
|
||||
|
||||
// Get size before deletion for quota tracking
|
||||
statInfo, err := s.Stat(ctx, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debug().Str("key", fullKey).Msg("Deleting object from S3")
|
||||
|
||||
input := &s3.DeleteObjectInput{
|
||||
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(s3Key),
|
||||
}
|
||||
Key: aws.String(fullKey),
|
||||
})
|
||||
|
||||
_, err = s.client.DeleteObject(ctx, input)
|
||||
if err != nil {
|
||||
metrics.RecordStorageOperation("s3", "delete", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to delete from S3")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to delete object from S3")
|
||||
}
|
||||
|
||||
// Update usage
|
||||
s.mu.Lock()
|
||||
s.used -= statInfo.Size
|
||||
currentUsed := s.used
|
||||
s.mu.Unlock()
|
||||
|
||||
metrics.RecordStorageOperation("s3", "delete", "success")
|
||||
metrics.UpdateCacheSize("s3", currentUsed)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists checks if a file exists in S3
|
||||
// Exists checks if data exists in S3
|
||||
func (s *S3Storage) Exists(ctx context.Context, key string) (bool, error) {
|
||||
s3Key := s.buildKey(key)
|
||||
fullKey := s.buildKey(key)
|
||||
|
||||
input := &s3.HeadObjectInput{
|
||||
_, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(s3Key),
|
||||
}
|
||||
Key: aws.String(fullKey),
|
||||
})
|
||||
|
||||
_, err := s.client.HeadObject(ctx, input)
|
||||
if err != nil {
|
||||
if isNotFoundError(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to check existence in S3")
|
||||
return false, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to check object existence in S3")
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// List lists files with prefix in S3
|
||||
// List returns a list of objects with the given prefix
|
||||
func (s *S3Storage) List(ctx context.Context, prefix string, opts *storage.ListOptions) ([]storage.StorageObject, error) {
|
||||
s3Prefix := s.buildKey(prefix)
|
||||
fullPrefix := s.buildKey(prefix)
|
||||
|
||||
input := &s3.ListObjectsV2Input{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Prefix: aws.String(s3Prefix),
|
||||
}
|
||||
log.Debug().Str("prefix", fullPrefix).Msg("Listing objects in S3")
|
||||
|
||||
var objects []storage.StorageObject
|
||||
paginator := s3.NewListObjectsV2Paginator(s.client, input)
|
||||
paginator := s3.NewListObjectsV2Paginator(s.client, &s3.ListObjectsV2Input{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Prefix: aws.String(fullPrefix),
|
||||
})
|
||||
|
||||
for paginator.HasMorePages() {
|
||||
page, err := paginator.NextPage(ctx)
|
||||
@@ -281,56 +226,58 @@ func (s *S3Storage) List(ctx context.Context, prefix string, opts *storage.ListO
|
||||
}
|
||||
|
||||
for _, obj := range page.Contents {
|
||||
key := s.stripPrefix(*obj.Key)
|
||||
objects = append(objects, storage.StorageObject{
|
||||
Key: key,
|
||||
Size: *obj.Size,
|
||||
Modified: *obj.LastModified,
|
||||
ETag: strings.Trim(*obj.ETag, "\""),
|
||||
})
|
||||
}
|
||||
}
|
||||
if obj.Key != nil {
|
||||
// Strip prefix from key
|
||||
key := s.stripPrefix(*obj.Key)
|
||||
|
||||
// Apply pagination if requested
|
||||
if opts != nil {
|
||||
start := opts.Offset
|
||||
end := len(objects)
|
||||
if opts.MaxResults > 0 && start+opts.MaxResults < end {
|
||||
end = start + opts.MaxResults
|
||||
}
|
||||
if start < len(objects) {
|
||||
objects = objects[start:end]
|
||||
} else {
|
||||
objects = []storage.StorageObject{}
|
||||
object := storage.StorageObject{
|
||||
Key: key,
|
||||
Size: aws.ToInt64(obj.Size),
|
||||
}
|
||||
|
||||
if obj.LastModified != nil {
|
||||
object.Modified = *obj.LastModified
|
||||
}
|
||||
|
||||
if obj.ETag != nil {
|
||||
object.ETag = *obj.ETag
|
||||
}
|
||||
|
||||
objects = append(objects, object)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return objects, nil
|
||||
}
|
||||
|
||||
// Stat gets file metadata from S3
|
||||
// Stat returns metadata about stored data
|
||||
func (s *S3Storage) Stat(ctx context.Context, key string) (*storage.StorageInfo, error) {
|
||||
s3Key := s.buildKey(key)
|
||||
fullKey := s.buildKey(key)
|
||||
|
||||
input := &s3.HeadObjectInput{
|
||||
result, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(s3Key),
|
||||
}
|
||||
Key: aws.String(fullKey),
|
||||
})
|
||||
|
||||
result, err := s.client.HeadObject(ctx, input)
|
||||
if err != nil {
|
||||
if isNotFoundError(err) {
|
||||
return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key))
|
||||
return nil, errors.NotFound(fmt.Sprintf("S3 object not found: %s", key))
|
||||
}
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to stat object in S3")
|
||||
}
|
||||
|
||||
info := &storage.StorageInfo{
|
||||
Key: key,
|
||||
Size: *result.ContentLength,
|
||||
Modified: *result.LastModified,
|
||||
ETag: strings.Trim(*result.ETag, "\""),
|
||||
Metadata: result.Metadata,
|
||||
Key: key,
|
||||
Size: aws.ToInt64(result.ContentLength),
|
||||
}
|
||||
|
||||
if result.LastModified != nil {
|
||||
info.Modified = *result.LastModified
|
||||
}
|
||||
|
||||
if result.ETag != nil {
|
||||
info.ETag = *result.ETag
|
||||
}
|
||||
|
||||
if result.ContentType != nil {
|
||||
@@ -340,33 +287,27 @@ func (s *S3Storage) Stat(ctx context.Context, key string) (*storage.StorageInfo,
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// GetQuota returns quota information
|
||||
// GetQuota returns current usage and quota information
|
||||
func (s *S3Storage) GetQuota(ctx context.Context) (*storage.QuotaInfo, error) {
|
||||
s.mu.RLock()
|
||||
used := s.used
|
||||
s.mu.RUnlock()
|
||||
|
||||
available := s.quota - used
|
||||
if available < 0 {
|
||||
available = 0
|
||||
usage, err := s.calculateUsage(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &storage.QuotaInfo{
|
||||
Used: used,
|
||||
Available: available,
|
||||
Limit: s.quota,
|
||||
Used: usage,
|
||||
Limit: s.maxSizeBytes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Health checks S3 health
|
||||
// Health checks if the S3 backend is healthy
|
||||
func (s *S3Storage) Health(ctx context.Context) error {
|
||||
// Try to list bucket to verify connectivity
|
||||
input := &s3.ListObjectsV2Input{
|
||||
// Try to list objects (lightweight operation)
|
||||
_, err := s.client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{
|
||||
Bucket: aws.String(s.bucket),
|
||||
MaxKeys: aws.Int32(1),
|
||||
}
|
||||
})
|
||||
|
||||
_, err := s.client.ListObjectsV2(ctx, input)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "S3 health check failed")
|
||||
}
|
||||
@@ -374,60 +315,51 @@ func (s *S3Storage) Health(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the storage backend
|
||||
// Close closes the S3 storage backend
|
||||
func (s *S3Storage) Close() error {
|
||||
// No cleanup needed for S3 client
|
||||
log.Info().Msg("S3 storage closed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildKey builds the full S3 key with prefix
|
||||
// buildKey constructs the full S3 key with prefix
|
||||
func (s *S3Storage) buildKey(key string) string {
|
||||
key = strings.TrimPrefix(key, "/")
|
||||
if s.prefix != "" {
|
||||
return s.prefix + "/" + key
|
||||
if s.prefix == "" {
|
||||
return key
|
||||
}
|
||||
return key
|
||||
return s.prefix + "/" + key
|
||||
}
|
||||
|
||||
// stripPrefix removes the configured prefix from an S3 key
|
||||
func (s *S3Storage) stripPrefix(s3Key string) string {
|
||||
if s.prefix != "" {
|
||||
return strings.TrimPrefix(s3Key, s.prefix+"/")
|
||||
// stripPrefix removes the prefix from an S3 key
|
||||
func (s *S3Storage) stripPrefix(key string) string {
|
||||
if s.prefix == "" {
|
||||
return key
|
||||
}
|
||||
return s3Key
|
||||
return strings.TrimPrefix(key, s.prefix+"/")
|
||||
}
|
||||
|
||||
// calculateUsage calculates current S3 storage usage
|
||||
func (s *S3Storage) calculateUsage(ctx context.Context) error {
|
||||
var total int64
|
||||
// calculateUsage calculates total storage usage
|
||||
func (s *S3Storage) calculateUsage(ctx context.Context) (int64, error) {
|
||||
var totalSize int64
|
||||
|
||||
input := &s3.ListObjectsV2Input{
|
||||
paginator := s3.NewListObjectsV2Paginator(s.client, &s3.ListObjectsV2Input{
|
||||
Bucket: aws.String(s.bucket),
|
||||
}
|
||||
|
||||
if s.prefix != "" {
|
||||
input.Prefix = aws.String(s.prefix + "/")
|
||||
}
|
||||
|
||||
paginator := s3.NewListObjectsV2Paginator(s.client, input)
|
||||
Prefix: aws.String(s.prefix),
|
||||
})
|
||||
|
||||
for paginator.HasMorePages() {
|
||||
page, err := paginator.NextPage(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, fmt.Errorf("failed to calculate usage: %w", err)
|
||||
}
|
||||
|
||||
for _, obj := range page.Contents {
|
||||
total += *obj.Size
|
||||
if obj.Size != nil {
|
||||
totalSize += aws.ToInt64(obj.Size)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.used = total
|
||||
s.mu.Unlock()
|
||||
|
||||
metrics.UpdateCacheSize("s3", total)
|
||||
return nil
|
||||
return totalSize, nil
|
||||
}
|
||||
|
||||
// isNotFoundError checks if an error is a "not found" error
|
||||
@@ -436,8 +368,21 @@ func isNotFoundError(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for specific S3 error types
|
||||
var notFound *types.NotFound
|
||||
var noSuchKey *types.NoSuchKey
|
||||
|
||||
return stderrors.As(err, ¬Found) || stderrors.As(err, &noSuchKey)
|
||||
// Use errors.As to check for wrapped errors
|
||||
if ok := stderrors.As(err, ¬Found); ok {
|
||||
return true
|
||||
}
|
||||
if ok := stderrors.As(err, &noSuchKey); ok {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check error message as fallback
|
||||
errMsg := err.Error()
|
||||
return strings.Contains(errMsg, "NoSuchKey") ||
|
||||
strings.Contains(errMsg, "NotFound") ||
|
||||
strings.Contains(errMsg, "404")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,271 @@
|
||||
package s3
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type S3StorageTestSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func TestS3StorageTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(S3StorageTestSuite))
|
||||
}
|
||||
|
||||
func (s *S3StorageTestSuite) TestNewS3Storage() {
|
||||
tests := []struct {
|
||||
name string
|
||||
config Config
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid config with credentials",
|
||||
config: Config{
|
||||
Region: "us-east-1",
|
||||
Bucket: "test-bucket",
|
||||
Prefix: "packages/",
|
||||
AccessKeyID: "AKIAIOSFODNN7EXAMPLE",
|
||||
SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
|
||||
MaxSizeBytes: 1024 * 1024,
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid config with custom endpoint",
|
||||
config: Config{
|
||||
Region: "us-east-1",
|
||||
Bucket: "test-bucket",
|
||||
Endpoint: "https://minio.example.com",
|
||||
AccessKeyID: "minioadmin",
|
||||
SecretAccessKey: "minioadmin",
|
||||
ForcePathStyle: true,
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid config with default region",
|
||||
config: Config{
|
||||
Bucket: "test-bucket",
|
||||
AccessKeyID: "test",
|
||||
SecretAccessKey: "test",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing bucket",
|
||||
config: Config{
|
||||
Region: "us-east-1",
|
||||
AccessKeyID: "test",
|
||||
SecretAccessKey: "test",
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "bucket is required",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
storage, err := New(tt.config)
|
||||
|
||||
if tt.expectError {
|
||||
s.Error(err)
|
||||
if tt.errorMsg != "" {
|
||||
s.Contains(err.Error(), tt.errorMsg)
|
||||
}
|
||||
s.Nil(storage)
|
||||
} else {
|
||||
s.NoError(err)
|
||||
s.NotNil(storage)
|
||||
s.Equal(tt.config.Bucket, storage.bucket)
|
||||
s.Equal(tt.config.MaxSizeBytes, storage.maxSizeBytes)
|
||||
|
||||
// Test prefix normalization
|
||||
if tt.config.Prefix != "" {
|
||||
s.NotContains(storage.prefix, "/", "prefix should not end with /")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *S3StorageTestSuite) TestBuildKey() {
|
||||
tests := []struct {
|
||||
name string
|
||||
prefix string
|
||||
key string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "with prefix",
|
||||
prefix: "packages",
|
||||
key: "test/file.txt",
|
||||
expected: "packages/test/file.txt",
|
||||
},
|
||||
{
|
||||
name: "without prefix",
|
||||
prefix: "",
|
||||
key: "test/file.txt",
|
||||
expected: "test/file.txt",
|
||||
},
|
||||
{
|
||||
name: "with trailing slash in prefix",
|
||||
prefix: "packages/",
|
||||
key: "test/file.txt",
|
||||
expected: "packages/test/file.txt",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
storage := &S3Storage{
|
||||
prefix: tt.prefix,
|
||||
}
|
||||
// Normalize prefix like in New()
|
||||
if storage.prefix != "" && storage.prefix[len(storage.prefix)-1] == '/' {
|
||||
storage.prefix = storage.prefix[:len(storage.prefix)-1]
|
||||
}
|
||||
|
||||
result := storage.buildKey(tt.key)
|
||||
s.Equal(tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *S3StorageTestSuite) TestStripPrefix() {
|
||||
tests := []struct {
|
||||
name string
|
||||
prefix string
|
||||
key string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "with prefix",
|
||||
prefix: "packages",
|
||||
key: "packages/test/file.txt",
|
||||
expected: "test/file.txt",
|
||||
},
|
||||
{
|
||||
name: "without prefix",
|
||||
prefix: "",
|
||||
key: "test/file.txt",
|
||||
expected: "test/file.txt",
|
||||
},
|
||||
{
|
||||
name: "key without prefix but prefix set",
|
||||
prefix: "packages",
|
||||
key: "test/file.txt",
|
||||
expected: "test/file.txt",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
storage := &S3Storage{
|
||||
prefix: tt.prefix,
|
||||
}
|
||||
|
||||
result := storage.stripPrefix(tt.key)
|
||||
s.Equal(tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *S3StorageTestSuite) TestIsNotFoundError() {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
result := isNotFoundError(tt.err)
|
||||
s.Equal(tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *S3StorageTestSuite) TestConfigDefaults() {
|
||||
config := Config{
|
||||
Bucket: "test-bucket",
|
||||
AccessKeyID: "test",
|
||||
SecretAccessKey: "test",
|
||||
}
|
||||
|
||||
storage, err := New(config)
|
||||
s.Require().NoError(err)
|
||||
s.NotNil(storage)
|
||||
|
||||
// Verify defaults
|
||||
s.Equal("test-bucket", storage.bucket)
|
||||
s.Equal("", storage.prefix)
|
||||
s.Equal(int64(0), storage.maxSizeBytes)
|
||||
}
|
||||
|
||||
func (s *S3StorageTestSuite) TestPrefixNormalization() {
|
||||
tests := []struct {
|
||||
name string
|
||||
inputPrefix string
|
||||
expectedPrefix string
|
||||
}{
|
||||
{
|
||||
name: "prefix with trailing slash",
|
||||
inputPrefix: "packages/",
|
||||
expectedPrefix: "packages",
|
||||
},
|
||||
{
|
||||
name: "prefix without trailing slash",
|
||||
inputPrefix: "packages",
|
||||
expectedPrefix: "packages",
|
||||
},
|
||||
{
|
||||
name: "empty prefix",
|
||||
inputPrefix: "",
|
||||
expectedPrefix: "",
|
||||
},
|
||||
{
|
||||
name: "nested prefix with trailing slash",
|
||||
inputPrefix: "cache/packages/",
|
||||
expectedPrefix: "cache/packages",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
config := Config{
|
||||
Bucket: "test-bucket",
|
||||
Prefix: tt.inputPrefix,
|
||||
AccessKeyID: "test",
|
||||
SecretAccessKey: "test",
|
||||
}
|
||||
|
||||
storage, err := New(config)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(tt.expectedPrefix, storage.prefix)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *S3StorageTestSuite) TestClose() {
|
||||
config := Config{
|
||||
Bucket: "test-bucket",
|
||||
AccessKeyID: "test",
|
||||
SecretAccessKey: "test",
|
||||
}
|
||||
|
||||
storage, err := New(config)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Close should not error
|
||||
err = storage.Close()
|
||||
s.NoError(err)
|
||||
}
|
||||
+266
-303
@@ -1,42 +1,43 @@
|
||||
package smb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/md5" // #nosec G501 -- MD5 used for file checksums, not cryptographic security
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hirochachacha/go-smb2"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/errors"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metrics"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/storage"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// SMBStorage implements storage.StorageBackend for SMB/CIFS shares
|
||||
type SMBStorage struct {
|
||||
host string
|
||||
share string
|
||||
basePath string
|
||||
username string
|
||||
password string
|
||||
quota int64
|
||||
mu sync.RWMutex
|
||||
used int64
|
||||
connPool chan *smbConnection
|
||||
poolSize int
|
||||
// Config holds SMB storage configuration
|
||||
type Config struct {
|
||||
Host string
|
||||
Share string
|
||||
Path string
|
||||
Username string
|
||||
Password string
|
||||
Domain string
|
||||
Port int
|
||||
MaxSizeBytes int64
|
||||
PoolSize int
|
||||
}
|
||||
|
||||
// smbConnection wraps an SMB session and share
|
||||
// SMBStorage implements storage.StorageBackend using SMB/CIFS
|
||||
type SMBStorage struct {
|
||||
connPool chan *smbConnection
|
||||
config Config
|
||||
maxSizeBytes int64
|
||||
poolSize int
|
||||
}
|
||||
|
||||
// smbConnection represents a pooled SMB connection
|
||||
type smbConnection struct {
|
||||
conn net.Conn
|
||||
session *smb2.Session
|
||||
@@ -44,27 +45,14 @@ type smbConnection struct {
|
||||
lastUse time.Time
|
||||
}
|
||||
|
||||
// Config holds SMB configuration
|
||||
type Config struct {
|
||||
Host string // SMB server hostname or IP
|
||||
Port int // SMB server port (default: 445)
|
||||
Share string // SMB share name
|
||||
BasePath string // Base path within the share
|
||||
Username string // SMB username
|
||||
Password string // SMB password
|
||||
Domain string // SMB domain (optional)
|
||||
Quota int64 // Quota in bytes (0 = unlimited)
|
||||
PoolSize int // Connection pool size (default: 5)
|
||||
}
|
||||
|
||||
// New creates a new SMB storage backend
|
||||
func New(ctx context.Context, cfg Config) (*SMBStorage, error) {
|
||||
func New(cfg Config) (*SMBStorage, error) {
|
||||
if cfg.Host == "" {
|
||||
return nil, errors.New(errors.ErrCodeInvalidConfig, "SMB host is required")
|
||||
return nil, fmt.Errorf("SMB host is required")
|
||||
}
|
||||
|
||||
if cfg.Share == "" {
|
||||
return nil, errors.New(errors.ErrCodeInvalidConfig, "SMB share is required")
|
||||
return nil, fmt.Errorf("SMB share is required")
|
||||
}
|
||||
|
||||
if cfg.Port == 0 {
|
||||
@@ -75,64 +63,68 @@ func New(ctx context.Context, cfg Config) (*SMBStorage, error) {
|
||||
cfg.PoolSize = 5 // Default pool size
|
||||
}
|
||||
|
||||
smbStorage := &SMBStorage{
|
||||
host: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
||||
share: cfg.Share,
|
||||
basePath: strings.Trim(cfg.BasePath, "/\\"),
|
||||
username: cfg.Username,
|
||||
password: cfg.Password,
|
||||
quota: cfg.Quota,
|
||||
connPool: make(chan *smbConnection, cfg.PoolSize),
|
||||
poolSize: cfg.PoolSize,
|
||||
// Normalize path
|
||||
cfg.Path = strings.Trim(cfg.Path, "/\\")
|
||||
|
||||
storage := &SMBStorage{
|
||||
config: cfg,
|
||||
maxSizeBytes: cfg.MaxSizeBytes,
|
||||
poolSize: cfg.PoolSize,
|
||||
connPool: make(chan *smbConnection, cfg.PoolSize),
|
||||
}
|
||||
|
||||
// Initialize connection pool
|
||||
// Pre-populate connection pool
|
||||
for i := 0; i < cfg.PoolSize; i++ {
|
||||
conn, err := smbStorage.createConnection(ctx)
|
||||
conn, err := storage.createConnection()
|
||||
if err != nil {
|
||||
// Clean up any created connections
|
||||
close(smbStorage.connPool)
|
||||
for c := range smbStorage.connPool {
|
||||
c.close()
|
||||
}
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create SMB connection pool")
|
||||
log.Warn().Err(err).Int("attempt", i).Msg("Failed to create initial SMB connection")
|
||||
continue
|
||||
}
|
||||
smbStorage.connPool <- conn
|
||||
storage.connPool <- conn
|
||||
}
|
||||
|
||||
// Calculate initial usage
|
||||
if err := smbStorage.calculateUsage(ctx); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to calculate initial SMB storage usage")
|
||||
}
|
||||
log.Info().
|
||||
Str("host", cfg.Host).
|
||||
Int("port", cfg.Port).
|
||||
Str("share", cfg.Share).
|
||||
Str("path", cfg.Path).
|
||||
Int("pool_size", cfg.PoolSize).
|
||||
Msg("SMB storage initialized")
|
||||
|
||||
return smbStorage, nil
|
||||
return storage, nil
|
||||
}
|
||||
|
||||
// createConnection creates a new SMB connection
|
||||
func (s *SMBStorage) createConnection(ctx context.Context) (*smbConnection, error) {
|
||||
conn, err := net.Dial("tcp", s.host)
|
||||
func (s *SMBStorage) createConnection() (*smbConnection, error) {
|
||||
// Connect to SMB server (use net.JoinHostPort for IPv6 compatibility)
|
||||
addr := net.JoinHostPort(s.config.Host, fmt.Sprintf("%d", s.config.Port))
|
||||
conn, err := net.DialTimeout("tcp", addr, 10*time.Second)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to connect to SMB server: %w", err)
|
||||
}
|
||||
|
||||
dialer := &smb2.Dialer{
|
||||
// Create SMB dialer
|
||||
d := &smb2.Dialer{
|
||||
Initiator: &smb2.NTLMInitiator{
|
||||
User: s.username,
|
||||
Password: s.password,
|
||||
User: s.config.Username,
|
||||
Password: s.config.Password,
|
||||
Domain: s.config.Domain,
|
||||
},
|
||||
}
|
||||
|
||||
session, err := dialer.Dial(conn)
|
||||
// Establish SMB session
|
||||
session, err := d.Dial(conn)
|
||||
if err != nil {
|
||||
conn.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
return nil, err
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("failed to establish SMB session: %w", err)
|
||||
}
|
||||
|
||||
share, err := session.Mount(s.share)
|
||||
// Mount share
|
||||
share, err := session.Mount(s.config.Share)
|
||||
if err != nil {
|
||||
_ = session.Logoff() // #nosec G104 -- SMB cleanup
|
||||
conn.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
return nil, err
|
||||
_ = session.Logoff()
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("failed to mount SMB share: %w", err)
|
||||
}
|
||||
|
||||
return &smbConnection{
|
||||
@@ -143,25 +135,34 @@ func (s *SMBStorage) createConnection(ctx context.Context) (*smbConnection, erro
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getConnection gets a connection from the pool
|
||||
func (s *SMBStorage) getConnection(ctx context.Context) (*smbConnection, error) {
|
||||
// getConnection gets a connection from the pool or creates a new one
|
||||
func (s *SMBStorage) getConnection() (*smbConnection, error) {
|
||||
select {
|
||||
case conn := <-s.connPool:
|
||||
// Check if connection is still valid (not older than 5 minutes idle)
|
||||
if time.Since(conn.lastUse) > 5*time.Minute {
|
||||
conn.close()
|
||||
return s.createConnection()
|
||||
}
|
||||
conn.lastUse = time.Now()
|
||||
return conn, nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(30 * time.Second):
|
||||
return nil, errors.New(errors.ErrCodeStorageFailure, "timeout waiting for SMB connection")
|
||||
default:
|
||||
// Pool is empty, create new connection
|
||||
return s.createConnection()
|
||||
}
|
||||
}
|
||||
|
||||
// returnConnection returns a connection to the pool
|
||||
func (s *SMBStorage) returnConnection(conn *smbConnection) {
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case s.connPool <- conn:
|
||||
// Successfully returned to pool
|
||||
default:
|
||||
// Pool is full, close the connection
|
||||
// Pool is full, close connection
|
||||
conn.close()
|
||||
}
|
||||
}
|
||||
@@ -169,189 +170,161 @@ func (s *SMBStorage) returnConnection(conn *smbConnection) {
|
||||
// close closes an SMB connection
|
||||
func (c *smbConnection) close() {
|
||||
if c.share != nil {
|
||||
_ = c.share.Umount() // #nosec G104 -- SMB cleanup
|
||||
if err := c.share.Umount(); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to unmount SMB share")
|
||||
}
|
||||
}
|
||||
if c.session != nil {
|
||||
_ = c.session.Logoff() // #nosec G104 -- SMB cleanup
|
||||
if err := c.session.Logoff(); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to logoff SMB session")
|
||||
}
|
||||
}
|
||||
if c.conn != nil {
|
||||
c.conn.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
if err := c.conn.Close(); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to close SMB connection")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a file from SMB share
|
||||
// Get retrieves data from SMB share
|
||||
func (s *SMBStorage) Get(ctx context.Context, key string) (io.ReadCloser, error) {
|
||||
conn, err := s.getConnection(ctx)
|
||||
conn, err := s.getConnection()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get SMB connection")
|
||||
}
|
||||
defer s.returnConnection(conn)
|
||||
|
||||
path := s.keyToPath(key)
|
||||
|
||||
log.Debug().Str("key", path).Msg("Getting file from SMB")
|
||||
|
||||
// Open file
|
||||
file, err := conn.share.Open(path)
|
||||
if err != nil {
|
||||
s.returnConnection(conn)
|
||||
if os.IsNotExist(err) {
|
||||
metrics.RecordStorageOperation("smb", "get", "not_found")
|
||||
return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key))
|
||||
return nil, errors.NotFound(fmt.Sprintf("SMB file not found: %s", key))
|
||||
}
|
||||
metrics.RecordStorageOperation("smb", "get", "error")
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to open SMB file")
|
||||
}
|
||||
|
||||
// Read entire file into memory and close SMB connection
|
||||
// This is necessary because we need to return the connection to the pool
|
||||
// Read entire file into memory (SMB files must be read completely before closing connection)
|
||||
data, err := io.ReadAll(file)
|
||||
file.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
s.returnConnection(conn)
|
||||
|
||||
if closeErr := file.Close(); closeErr != nil {
|
||||
log.Warn().Err(closeErr).Str("path", path).Msg("Failed to close SMB file after reading")
|
||||
}
|
||||
if err != nil {
|
||||
metrics.RecordStorageOperation("smb", "get", "error")
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to read SMB file")
|
||||
}
|
||||
|
||||
metrics.RecordStorageOperation("smb", "get", "success")
|
||||
return io.NopCloser(bytes.NewReader(data)), nil
|
||||
// Return as ReadCloser
|
||||
return io.NopCloser(strings.NewReader(string(data))), nil
|
||||
}
|
||||
|
||||
// Put stores a file on SMB share
|
||||
// Put stores data on SMB share
|
||||
func (s *SMBStorage) Put(ctx context.Context, key string, data io.Reader, opts *storage.PutOptions) error {
|
||||
conn, err := s.getConnection(ctx)
|
||||
conn, err := s.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get SMB connection")
|
||||
}
|
||||
defer s.returnConnection(conn)
|
||||
|
||||
path := s.keyToPath(key)
|
||||
dir := filepath.Dir(path)
|
||||
|
||||
// Create directory structure
|
||||
if err := conn.share.MkdirAll(dir, 0755); err != nil {
|
||||
metrics.RecordStorageOperation("smb", "put", "error")
|
||||
log.Debug().Str("key", path).Msg("Putting file to SMB")
|
||||
|
||||
// Ensure directory exists
|
||||
dir := filepath.Dir(path)
|
||||
if err := s.ensureDir(conn, dir); err != nil {
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create SMB directory")
|
||||
}
|
||||
|
||||
// Read data into buffer to calculate checksums and size
|
||||
var buf bytes.Buffer
|
||||
md5Hash := md5.New() // #nosec G401 -- MD5 used for file integrity check, not cryptographic security
|
||||
sha256Hash := sha256.New()
|
||||
multiWriter := io.MultiWriter(&buf, md5Hash, sha256Hash)
|
||||
|
||||
written, err := io.Copy(multiWriter, data)
|
||||
// Read data into buffer to check quota
|
||||
buf := new(strings.Builder)
|
||||
size, err := io.Copy(buf, data)
|
||||
if err != nil {
|
||||
metrics.RecordStorageOperation("smb", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to read data")
|
||||
}
|
||||
|
||||
// Check quota
|
||||
if s.quota > 0 {
|
||||
s.mu.RLock()
|
||||
used := s.used
|
||||
s.mu.RUnlock()
|
||||
|
||||
if used+written > s.quota {
|
||||
metrics.RecordStorageOperation("smb", "put", "quota_exceeded")
|
||||
return errors.QuotaExceeded(s.quota)
|
||||
// Check quota if set
|
||||
if s.maxSizeBytes > 0 {
|
||||
currentUsage, err := s.calculateUsage(conn)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to calculate current usage, skipping quota check")
|
||||
} else if currentUsage+size > s.maxSizeBytes {
|
||||
return errors.QuotaExceeded(s.maxSizeBytes)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify checksums if provided
|
||||
if opts != nil {
|
||||
md5Sum := hex.EncodeToString(md5Hash.Sum(nil))
|
||||
sha256Sum := hex.EncodeToString(sha256Hash.Sum(nil))
|
||||
|
||||
if opts.ChecksumMD5 != "" && opts.ChecksumMD5 != md5Sum {
|
||||
metrics.RecordStorageOperation("smb", "put", "checksum_error")
|
||||
return errors.New(errors.ErrCodeChecksumMismatch, "MD5 checksum mismatch")
|
||||
}
|
||||
|
||||
if opts.ChecksumSHA256 != "" && opts.ChecksumSHA256 != sha256Sum {
|
||||
metrics.RecordStorageOperation("smb", "put", "checksum_error")
|
||||
return errors.New(errors.ErrCodeChecksumMismatch, "SHA256 checksum mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
// Create temp file for atomic write
|
||||
tempPath := path + ".tmp"
|
||||
file, err := conn.share.Create(tempPath)
|
||||
// Create/overwrite file
|
||||
file, err := conn.share.Create(path)
|
||||
if err != nil {
|
||||
metrics.RecordStorageOperation("smb", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create SMB temp file")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create SMB file")
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Write data
|
||||
_, err = io.Copy(file, bytes.NewReader(buf.Bytes()))
|
||||
file.Close() // #nosec G104 -- Cleanup, error not critical
|
||||
|
||||
_, err = file.Write([]byte(buf.String()))
|
||||
if err != nil {
|
||||
_ = conn.share.Remove(tempPath) // #nosec G104 -- SMB cleanup
|
||||
metrics.RecordStorageOperation("smb", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to write SMB file")
|
||||
}
|
||||
|
||||
// Atomic rename
|
||||
if err := conn.share.Rename(tempPath, path); err != nil {
|
||||
_ = conn.share.Remove(tempPath) // #nosec G104 -- SMB cleanup
|
||||
metrics.RecordStorageOperation("smb", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to rename SMB temp file")
|
||||
}
|
||||
|
||||
// Update usage
|
||||
s.mu.Lock()
|
||||
s.used += written
|
||||
currentUsed := s.used
|
||||
s.mu.Unlock()
|
||||
|
||||
metrics.RecordStorageOperation("smb", "put", "success")
|
||||
metrics.UpdateCacheSize("smb", currentUsed)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a file from SMB share
|
||||
func (s *SMBStorage) Delete(ctx context.Context, key string) error {
|
||||
conn, err := s.getConnection(ctx)
|
||||
if err != nil {
|
||||
// ensureDir ensures a directory exists on SMB share
|
||||
func (s *SMBStorage) ensureDir(conn *smbConnection, path string) error {
|
||||
if path == "" || path == "." || path == "/" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to stat the directory
|
||||
_, err := conn.share.Stat(path)
|
||||
if err == nil {
|
||||
return nil // Directory exists
|
||||
}
|
||||
|
||||
// Create parent directory first
|
||||
parent := filepath.Dir(path)
|
||||
if parent != path && parent != "." && parent != "/" {
|
||||
if err := s.ensureDir(conn, parent); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Create this directory
|
||||
err = conn.share.Mkdir(path, 0755)
|
||||
if err != nil && !os.IsExist(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes data from SMB share
|
||||
func (s *SMBStorage) Delete(ctx context.Context, key string) error {
|
||||
conn, err := s.getConnection()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get SMB connection")
|
||||
}
|
||||
defer s.returnConnection(conn)
|
||||
|
||||
path := s.keyToPath(key)
|
||||
|
||||
// Get size before deletion
|
||||
info, err := conn.share.Stat(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
metrics.RecordStorageOperation("smb", "delete", "not_found")
|
||||
return errors.NotFound(fmt.Sprintf("file not found: %s", key))
|
||||
}
|
||||
metrics.RecordStorageOperation("smb", "delete", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to stat SMB file")
|
||||
}
|
||||
log.Debug().Str("key", path).Msg("Deleting file from SMB")
|
||||
|
||||
size := info.Size()
|
||||
|
||||
if err := conn.share.Remove(path); err != nil {
|
||||
metrics.RecordStorageOperation("smb", "delete", "error")
|
||||
err = conn.share.Remove(path)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to delete SMB file")
|
||||
}
|
||||
|
||||
// Update usage
|
||||
s.mu.Lock()
|
||||
s.used -= size
|
||||
currentUsed := s.used
|
||||
s.mu.Unlock()
|
||||
|
||||
metrics.RecordStorageOperation("smb", "delete", "success")
|
||||
metrics.UpdateCacheSize("smb", currentUsed)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists checks if a file exists on SMB share
|
||||
// Exists checks if data exists on SMB share
|
||||
func (s *SMBStorage) Exists(ctx context.Context, key string) (bool, error) {
|
||||
conn, err := s.getConnection(ctx)
|
||||
conn, err := s.getConnection()
|
||||
if err != nil {
|
||||
return false, err
|
||||
return false, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get SMB connection")
|
||||
}
|
||||
defer s.returnConnection(conn)
|
||||
|
||||
@@ -368,57 +341,90 @@ func (s *SMBStorage) Exists(ctx context.Context, key string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// List lists files with prefix on SMB share
|
||||
// List returns a list of objects with the given prefix
|
||||
func (s *SMBStorage) List(ctx context.Context, prefix string, opts *storage.ListOptions) ([]storage.StorageObject, error) {
|
||||
conn, err := s.getConnection(ctx)
|
||||
conn, err := s.getConnection()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get SMB connection")
|
||||
}
|
||||
defer s.returnConnection(conn)
|
||||
|
||||
searchPath := s.keyToPath(prefix)
|
||||
basePath := s.keyToPath(prefix)
|
||||
|
||||
log.Debug().Str("prefix", basePath).Msg("Listing files in SMB")
|
||||
|
||||
var objects []storage.StorageObject
|
||||
|
||||
err = s.walkPath(conn.share, searchPath, func(path string, info os.FileInfo) error {
|
||||
// Walk the directory tree
|
||||
err = s.walkPath(conn, basePath, func(path string, info os.FileInfo) error {
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert path back to key
|
||||
key := s.pathToKey(path)
|
||||
|
||||
objects = append(objects, storage.StorageObject{
|
||||
Key: key,
|
||||
Size: info.Size(),
|
||||
Modified: info.ModTime(),
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to list SMB files")
|
||||
}
|
||||
|
||||
// Apply pagination if requested
|
||||
if opts != nil {
|
||||
start := opts.Offset
|
||||
end := len(objects)
|
||||
if opts.MaxResults > 0 && start+opts.MaxResults < end {
|
||||
end = start + opts.MaxResults
|
||||
}
|
||||
if start < len(objects) {
|
||||
objects = objects[start:end]
|
||||
} else {
|
||||
objects = []storage.StorageObject{}
|
||||
}
|
||||
}
|
||||
|
||||
return objects, nil
|
||||
}
|
||||
|
||||
// Stat gets file metadata from SMB share
|
||||
func (s *SMBStorage) Stat(ctx context.Context, key string) (*storage.StorageInfo, error) {
|
||||
conn, err := s.getConnection(ctx)
|
||||
// walkPath walks a directory tree on SMB share
|
||||
func (s *SMBStorage) walkPath(conn *smbConnection, root string, fn func(string, os.FileInfo) error) error {
|
||||
// Check if root exists
|
||||
info, err := conn.share.Stat(root)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if os.IsNotExist(err) {
|
||||
return nil // Empty directory
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// If root is a file, process it directly
|
||||
if !info.IsDir() {
|
||||
return fn(root, info)
|
||||
}
|
||||
|
||||
// List directory contents
|
||||
entries, err := conn.share.ReadDir(root)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
fullPath := filepath.Join(root, entry.Name())
|
||||
|
||||
if err := fn(fullPath, entry); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Recurse into subdirectories
|
||||
if entry.IsDir() {
|
||||
if err := s.walkPath(conn, fullPath, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stat returns metadata about stored data
|
||||
func (s *SMBStorage) Stat(ctx context.Context, key string) (*storage.StorageInfo, error) {
|
||||
conn, err := s.getConnection()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get SMB connection")
|
||||
}
|
||||
defer s.returnConnection(conn)
|
||||
|
||||
@@ -427,7 +433,7 @@ func (s *SMBStorage) Stat(ctx context.Context, key string) (*storage.StorageInfo
|
||||
info, err := conn.share.Stat(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key))
|
||||
return nil, errors.NotFound(fmt.Sprintf("SMB file not found: %s", key))
|
||||
}
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to stat SMB file")
|
||||
}
|
||||
@@ -439,35 +445,35 @@ func (s *SMBStorage) Stat(ctx context.Context, key string) (*storage.StorageInfo
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetQuota returns quota information
|
||||
// GetQuota returns current usage and quota information
|
||||
func (s *SMBStorage) GetQuota(ctx context.Context) (*storage.QuotaInfo, error) {
|
||||
s.mu.RLock()
|
||||
used := s.used
|
||||
s.mu.RUnlock()
|
||||
conn, err := s.getConnection()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get SMB connection")
|
||||
}
|
||||
defer s.returnConnection(conn)
|
||||
|
||||
available := s.quota - used
|
||||
if available < 0 {
|
||||
available = 0
|
||||
usage, err := s.calculateUsage(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &storage.QuotaInfo{
|
||||
Used: used,
|
||||
Available: available,
|
||||
Limit: s.quota,
|
||||
Used: usage,
|
||||
Limit: s.maxSizeBytes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Health checks SMB health
|
||||
// Health checks if the SMB backend is healthy
|
||||
func (s *SMBStorage) Health(ctx context.Context) error {
|
||||
conn, err := s.getConnection(ctx)
|
||||
conn, err := s.getConnection()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "SMB health check failed - connection error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "SMB health check failed: cannot get connection")
|
||||
}
|
||||
defer s.returnConnection(conn)
|
||||
|
||||
// Try to stat the base path
|
||||
path := s.keyToPath("")
|
||||
_, err = conn.share.Stat(path)
|
||||
_, err = conn.share.Stat(s.config.Path)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "SMB health check failed")
|
||||
}
|
||||
@@ -475,105 +481,62 @@ func (s *SMBStorage) Health(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the storage backend
|
||||
// Close closes the SMB storage backend
|
||||
func (s *SMBStorage) Close() error {
|
||||
close(s.connPool)
|
||||
|
||||
// Close all connections in pool
|
||||
for conn := range s.connPool {
|
||||
conn.close()
|
||||
}
|
||||
|
||||
log.Info().Msg("SMB storage closed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// keyToPath converts a storage key to SMB path
|
||||
func (s *SMBStorage) keyToPath(key string) string {
|
||||
key = strings.TrimPrefix(key, "/")
|
||||
key = filepath.Clean(key)
|
||||
// Normalize separators to backslash for SMB
|
||||
key = strings.ReplaceAll(key, "/", "\\")
|
||||
|
||||
// Remove path traversal attempts
|
||||
for strings.HasPrefix(key, "../") || strings.HasPrefix(key, "..\\") {
|
||||
key = strings.TrimPrefix(key, "../")
|
||||
key = strings.TrimPrefix(key, "..\\")
|
||||
if s.config.Path == "" {
|
||||
return key
|
||||
}
|
||||
|
||||
key = filepath.Clean(key)
|
||||
if key == ".." || strings.HasPrefix(key, "../") || strings.HasPrefix(key, "..\\") {
|
||||
key = ""
|
||||
}
|
||||
|
||||
if s.basePath != "" {
|
||||
return filepath.Join(s.basePath, key)
|
||||
}
|
||||
return key
|
||||
// Use backslash for SMB paths
|
||||
return s.config.Path + "\\" + key
|
||||
}
|
||||
|
||||
// pathToKey converts an SMB path back to a storage key
|
||||
// pathToKey converts an SMB path to storage key
|
||||
func (s *SMBStorage) pathToKey(path string) string {
|
||||
if s.basePath != "" {
|
||||
path = strings.TrimPrefix(path, s.basePath)
|
||||
path = strings.TrimPrefix(path, "/")
|
||||
path = strings.TrimPrefix(path, "\\")
|
||||
// Remove base path
|
||||
if s.config.Path != "" {
|
||||
path = strings.TrimPrefix(path, s.config.Path+"\\")
|
||||
}
|
||||
return filepath.ToSlash(path)
|
||||
|
||||
// Convert backslashes to forward slashes for consistency
|
||||
return strings.ReplaceAll(path, "\\", "/")
|
||||
}
|
||||
|
||||
// walkPath recursively walks an SMB directory
|
||||
func (s *SMBStorage) walkPath(share *smb2.Share, path string, fn func(string, os.FileInfo) error) error {
|
||||
info, err := share.Stat(path)
|
||||
if err != nil {
|
||||
return err
|
||||
// calculateUsage calculates total storage usage
|
||||
func (s *SMBStorage) calculateUsage(conn *smbConnection) (int64, error) {
|
||||
var totalSize int64
|
||||
|
||||
basePath := s.config.Path
|
||||
if basePath == "" {
|
||||
basePath = "\\"
|
||||
}
|
||||
|
||||
if !info.IsDir() {
|
||||
return fn(path, info)
|
||||
}
|
||||
|
||||
entries, err := share.ReadDir(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
entryPath := filepath.Join(path, entry.Name())
|
||||
if entry.IsDir() {
|
||||
if err := s.walkPath(share, entryPath, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := fn(entryPath, entry); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateUsage calculates current SMB storage usage
|
||||
func (s *SMBStorage) calculateUsage(ctx context.Context) error {
|
||||
conn, err := s.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.returnConnection(conn)
|
||||
|
||||
var total int64
|
||||
basePath := s.keyToPath("")
|
||||
|
||||
err = s.walkPath(conn.share, basePath, func(path string, info os.FileInfo) error {
|
||||
err := s.walkPath(conn, basePath, func(path string, info os.FileInfo) error {
|
||||
if !info.IsDir() {
|
||||
total += info.Size()
|
||||
totalSize += info.Size()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to calculate usage: %w", err)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.used = total
|
||||
s.mu.Unlock()
|
||||
|
||||
metrics.UpdateCacheSize("smb", total)
|
||||
return nil
|
||||
return totalSize, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,327 @@
|
||||
package smb
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type SMBStorageTestSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func TestSMBStorageTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(SMBStorageTestSuite))
|
||||
}
|
||||
|
||||
func (s *SMBStorageTestSuite) TestNewSMBStorage() {
|
||||
tests := []struct {
|
||||
name string
|
||||
errorMsg string
|
||||
config Config
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
config: Config{
|
||||
Host: "fileserver.example.com",
|
||||
Port: 445,
|
||||
Share: "gohoarder",
|
||||
Path: "packages",
|
||||
Username: "testuser",
|
||||
Password: "testpass",
|
||||
Domain: "CORP",
|
||||
MaxSizeBytes: 1024 * 1024,
|
||||
PoolSize: 5,
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing host",
|
||||
config: Config{
|
||||
Share: "gohoarder",
|
||||
Username: "testuser",
|
||||
Password: "testpass",
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "host is required",
|
||||
},
|
||||
{
|
||||
name: "missing share",
|
||||
config: Config{
|
||||
Host: "fileserver.example.com",
|
||||
Username: "testuser",
|
||||
Password: "testpass",
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "share is required",
|
||||
},
|
||||
{
|
||||
name: "default port",
|
||||
config: Config{
|
||||
Host: "fileserver.example.com",
|
||||
Share: "gohoarder",
|
||||
Username: "testuser",
|
||||
Password: "testpass",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "default pool size",
|
||||
config: Config{
|
||||
Host: "fileserver.example.com",
|
||||
Share: "gohoarder",
|
||||
Username: "testuser",
|
||||
Password: "testpass",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
storage, err := New(tt.config)
|
||||
|
||||
if tt.expectError {
|
||||
s.Error(err)
|
||||
if tt.errorMsg != "" {
|
||||
s.Contains(err.Error(), tt.errorMsg)
|
||||
}
|
||||
s.Nil(storage)
|
||||
} else {
|
||||
// Note: This will fail in actual execution since we can't connect to a real SMB server
|
||||
// But it tests the validation logic
|
||||
if err != nil {
|
||||
// Connection errors are expected in unit tests
|
||||
s.Contains(err.Error(), "Failed to create initial SMB connection")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SMBStorageTestSuite) TestKeyToPath() {
|
||||
tests := []struct {
|
||||
name string
|
||||
basePath string
|
||||
key string
|
||||
expectedWin string // Expected Windows-style path
|
||||
}{
|
||||
{
|
||||
name: "simple key with base path",
|
||||
basePath: "packages",
|
||||
key: "test/file.txt",
|
||||
expectedWin: "packages\\test\\file.txt",
|
||||
},
|
||||
{
|
||||
name: "simple key without base path",
|
||||
basePath: "",
|
||||
key: "test/file.txt",
|
||||
expectedWin: "test\\file.txt",
|
||||
},
|
||||
{
|
||||
name: "nested key",
|
||||
basePath: "cache",
|
||||
key: "deep/nested/path/file.txt",
|
||||
expectedWin: "cache\\deep\\nested\\path\\file.txt",
|
||||
},
|
||||
{
|
||||
name: "key with backslashes",
|
||||
basePath: "packages",
|
||||
key: "test\\file.txt",
|
||||
expectedWin: "packages\\test\\file.txt",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
storage := &SMBStorage{
|
||||
config: Config{
|
||||
Path: tt.basePath,
|
||||
},
|
||||
}
|
||||
|
||||
result := storage.keyToPath(tt.key)
|
||||
s.Equal(tt.expectedWin, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SMBStorageTestSuite) TestPathToKey() {
|
||||
tests := []struct {
|
||||
name string
|
||||
basePath string
|
||||
path string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "windows path with base path",
|
||||
basePath: "packages",
|
||||
path: "packages\\test\\file.txt",
|
||||
expected: "test/file.txt",
|
||||
},
|
||||
{
|
||||
name: "windows path without base path",
|
||||
basePath: "",
|
||||
path: "test\\file.txt",
|
||||
expected: "test/file.txt",
|
||||
},
|
||||
{
|
||||
name: "nested windows path",
|
||||
basePath: "cache",
|
||||
path: "cache\\deep\\nested\\path\\file.txt",
|
||||
expected: "deep/nested/path/file.txt",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
storage := &SMBStorage{
|
||||
config: Config{
|
||||
Path: tt.basePath,
|
||||
},
|
||||
}
|
||||
|
||||
result := storage.pathToKey(tt.path)
|
||||
s.Equal(tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SMBStorageTestSuite) TestConfigDefaults() {
|
||||
config := Config{
|
||||
Host: "fileserver.example.com",
|
||||
Share: "gohoarder",
|
||||
Username: "testuser",
|
||||
Password: "testpass",
|
||||
}
|
||||
|
||||
// This will fail to connect, but we can verify the config validation
|
||||
_, err := New(config)
|
||||
|
||||
// We expect a connection error, not a validation error
|
||||
if err != nil {
|
||||
s.NotContains(err.Error(), "host is required")
|
||||
s.NotContains(err.Error(), "share is required")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SMBStorageTestSuite) TestPathNormalization() {
|
||||
tests := []struct {
|
||||
name string
|
||||
inputPath string
|
||||
expectedPath string
|
||||
}{
|
||||
{
|
||||
name: "path with trailing slash",
|
||||
inputPath: "packages/",
|
||||
expectedPath: "packages",
|
||||
},
|
||||
{
|
||||
name: "path with trailing backslash",
|
||||
inputPath: "packages\\",
|
||||
expectedPath: "packages",
|
||||
},
|
||||
{
|
||||
name: "path without trailing slash",
|
||||
inputPath: "packages",
|
||||
expectedPath: "packages",
|
||||
},
|
||||
{
|
||||
name: "empty path",
|
||||
inputPath: "",
|
||||
expectedPath: "",
|
||||
},
|
||||
{
|
||||
name: "nested path with trailing slash",
|
||||
inputPath: "cache/packages/",
|
||||
expectedPath: "cache/packages",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
config := Config{
|
||||
Host: "fileserver.example.com",
|
||||
Share: "gohoarder",
|
||||
Path: tt.inputPath,
|
||||
Username: "testuser",
|
||||
Password: "testpass",
|
||||
}
|
||||
|
||||
// This will fail to connect, but we can check the config
|
||||
storage, _ := New(config)
|
||||
if storage != nil {
|
||||
s.Equal(tt.expectedPath, storage.config.Path)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SMBStorageTestSuite) TestPoolSizeDefaults() {
|
||||
config := Config{
|
||||
Host: "fileserver.example.com",
|
||||
Share: "gohoarder",
|
||||
Username: "testuser",
|
||||
Password: "testpass",
|
||||
}
|
||||
|
||||
storage, _ := New(config)
|
||||
if storage != nil {
|
||||
s.Equal(5, storage.poolSize) // Default pool size
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SMBStorageTestSuite) TestPortDefaults() {
|
||||
config := Config{
|
||||
Host: "fileserver.example.com",
|
||||
Share: "gohoarder",
|
||||
Username: "testuser",
|
||||
Password: "testpass",
|
||||
}
|
||||
|
||||
storage, _ := New(config)
|
||||
if storage != nil {
|
||||
s.Equal(445, storage.config.Port) // Default SMB port
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SMBStorageTestSuite) TestClose() {
|
||||
// Create a storage instance (will fail to connect but that's ok)
|
||||
config := Config{
|
||||
Host: "fileserver.example.com",
|
||||
Share: "gohoarder",
|
||||
Username: "testuser",
|
||||
Password: "testpass",
|
||||
}
|
||||
|
||||
storage, _ := New(config)
|
||||
if storage != nil {
|
||||
// Close should not panic
|
||||
err := storage.Close()
|
||||
s.NoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SMBStorageTestSuite) TestConnectionPoolChannel() {
|
||||
config := Config{
|
||||
Host: "fileserver.example.com",
|
||||
Share: "gohoarder",
|
||||
Username: "testuser",
|
||||
Password: "testpass",
|
||||
PoolSize: 10,
|
||||
}
|
||||
|
||||
storage, _ := New(config)
|
||||
if storage != nil {
|
||||
// Verify pool channel capacity
|
||||
s.NotNil(storage.connPool)
|
||||
s.Equal(10, cap(storage.connPool))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SMBStorageTestSuite) TestSMBConnectionStruct() {
|
||||
// Verify smbConnection structure exists and has required fields
|
||||
conn := &smbConnection{}
|
||||
s.NotNil(conn)
|
||||
}
|
||||
@@ -56,8 +56,8 @@ func TestNew(t *testing.T) {
|
||||
func TestString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uuid UUID
|
||||
expected string
|
||||
uuid UUID
|
||||
}{
|
||||
{
|
||||
name: "zero UUID",
|
||||
|
||||
+1
-1
@@ -14,9 +14,9 @@ import (
|
||||
|
||||
// GitFetcher handles git repository operations
|
||||
type GitFetcher struct {
|
||||
credStore *CredentialStore
|
||||
workDir string
|
||||
timeout time.Duration
|
||||
credStore *CredentialStore
|
||||
}
|
||||
|
||||
// NewGitFetcher creates a new git fetcher
|
||||
|
||||
+1
-1
@@ -27,8 +27,8 @@ func NewModuleBuilder() *ModuleBuilder {
|
||||
|
||||
// ModuleInfo represents Go module version metadata (.info file)
|
||||
type ModuleInfo struct {
|
||||
Version string `json:"Version"`
|
||||
Time time.Time `json:"Time"`
|
||||
Version string `json:"Version"`
|
||||
}
|
||||
|
||||
// BuildModuleZip creates a Go module zip from source directory
|
||||
|
||||
+4
-11
@@ -25,9 +25,9 @@ const (
|
||||
|
||||
// Event represents a WebSocket event message
|
||||
type Event struct {
|
||||
Type EventType `json:"type"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
Type EventType `json:"type"`
|
||||
}
|
||||
|
||||
// Client represents a WebSocket client connection
|
||||
@@ -45,15 +45,15 @@ type Server struct {
|
||||
broadcast chan Event
|
||||
register chan *Client
|
||||
unregister chan *Client
|
||||
mu sync.RWMutex
|
||||
upgrader websocket.Upgrader
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Config holds WebSocket server configuration
|
||||
type Config struct {
|
||||
CheckOrigin func(r *http.Request) bool
|
||||
ReadBufferSize int
|
||||
WriteBufferSize int
|
||||
CheckOrigin func(r *http.Request) bool
|
||||
}
|
||||
|
||||
// NewServer creates a new WebSocket server
|
||||
@@ -307,8 +307,8 @@ func (c *Client) writePump() {
|
||||
// handleMessage processes incoming client messages
|
||||
func (c *Client) handleMessage(message []byte) {
|
||||
var msg struct {
|
||||
Action string `json:"action"`
|
||||
Data interface{} `json:"data"`
|
||||
Action string `json:"action"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(message, &msg); err != nil {
|
||||
@@ -379,10 +379,3 @@ func (c *Client) sendPong() {
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// GetConnectedClients returns the number of connected clients
|
||||
func (s *Server) GetConnectedClients() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return len(s.clients)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user