diff --git a/pkg/analytics/analytics.go b/pkg/analytics/analytics.go index bf9c5ad..379f04c 100644 --- a/pkg/analytics/analytics.go +++ b/pkg/analytics/analytics.go @@ -10,23 +10,23 @@ import ( // PackageDownload represents a package download event type PackageDownload struct { + Timestamp time.Time Registry string Name string Version string - Timestamp time.Time - BytesSize int64 ClientIP string UserAgent string + BytesSize int64 } // PackageStats holds statistics for a package type PackageStats struct { + LastDownload time.Time + FirstSeen time.Time Registry string Name string TotalDownloads int64 UniqueVersions int - LastDownload time.Time - FirstSeen time.Time BytesServed int64 } @@ -48,13 +48,13 @@ type PopularPackage struct { // Engine tracks and analyzes package downloads type Engine struct { - downloads []PackageDownload - downloadsMu sync.RWMutex - stats map[string]*PackageStats // key: registry:name - statsMu sync.RWMutex - maxEvents int + stats map[string]*PackageStats flushTicker *time.Ticker stopChan chan struct{} + downloads []PackageDownload + maxEvents int + downloadsMu sync.RWMutex + statsMu sync.RWMutex } // Config holds analytics engine configuration diff --git a/pkg/app/app.go b/pkg/app/app.go index 81dc777..6f06504 100644 --- a/pkg/app/app.go +++ b/pkg/app/app.go @@ -17,7 +17,6 @@ import ( "github.com/lukaszraczylo/gohoarder/pkg/cdn" "github.com/lukaszraczylo/gohoarder/pkg/config" "github.com/lukaszraczylo/gohoarder/pkg/health" - "github.com/lukaszraczylo/gohoarder/pkg/lock" "github.com/lukaszraczylo/gohoarder/pkg/metadata" metafile "github.com/lukaszraczylo/gohoarder/pkg/metadata/file" metasqlite "github.com/lukaszraczylo/gohoarder/pkg/metadata/sqlite" @@ -30,6 +29,8 @@ import ( "github.com/lukaszraczylo/gohoarder/pkg/scanner" "github.com/lukaszraczylo/gohoarder/pkg/storage" "github.com/lukaszraczylo/gohoarder/pkg/storage/filesystem" + "github.com/lukaszraczylo/gohoarder/pkg/storage/s3" + "github.com/lukaszraczylo/gohoarder/pkg/storage/smb" "github.com/lukaszraczylo/gohoarder/pkg/vcs" "github.com/lukaszraczylo/gohoarder/pkg/websocket" "github.com/rs/zerolog/log" @@ -50,7 +51,6 @@ type App struct { analyticsEngine *analytics.Engine wsServer *websocket.Server prewarmWorker *prewarming.Worker - lockManager *lock.Manager cdnMiddleware *cdn.Middleware } @@ -82,7 +82,33 @@ func (a *App) initializeComponents() error { switch a.config.Storage.Backend { case "filesystem": a.storage, err = filesystem.New(a.config.Storage.Path, a.config.Cache.MaxSizeBytes) + case "s3": + a.storage, err = s3.New(s3.Config{ + Region: a.config.Storage.S3.Region, + Bucket: a.config.Storage.S3.Bucket, + Prefix: a.config.Storage.S3.Prefix, + AccessKeyID: a.config.Storage.S3.AccessKeyID, + SecretAccessKey: a.config.Storage.S3.SecretAccessKey, + Endpoint: a.config.Storage.S3.Endpoint, + ForcePathStyle: a.config.Storage.S3.ForcePathStyle, + MaxSizeBytes: a.config.Cache.MaxSizeBytes, + }) + case "smb": + a.storage, err = smb.New(smb.Config{ + Host: a.config.Storage.SMB.Host, + Port: 445, // Default SMB port + Share: a.config.Storage.SMB.Share, + Path: a.config.Storage.Path, + Username: a.config.Storage.SMB.Username, + Password: a.config.Storage.SMB.Password, + Domain: a.config.Storage.SMB.Domain, + MaxSizeBytes: a.config.Cache.MaxSizeBytes, + PoolSize: 5, // Default connection pool size + }) default: + log.Warn(). + Str("backend", a.config.Storage.Backend). + Msg("Unknown storage backend, defaulting to filesystem") a.storage, err = filesystem.New(a.config.Storage.Path, a.config.Cache.MaxSizeBytes) } if err != nil { @@ -116,9 +142,16 @@ func (a *App) initializeComponents() error { return fmt.Errorf("failed to initialize scanner: %w", err) } - // Initialize cache manager with scanner + // Initialize analytics engine first (needed by cache) + log.Info().Msg("Initializing analytics engine") + a.analyticsEngine = analytics.NewEngine(analytics.Config{ + MaxEvents: 10000, + FlushInterval: 5 * time.Minute, + }) + + // Initialize cache manager with scanner and analytics log.Info().Msg("Initializing cache manager") - a.cache, err = cache.New(a.storage, a.metadata, a.scanManager, cache.Config{ + a.cache, err = cache.New(a.storage, a.metadata, a.scanManager, a.analyticsEngine, cache.Config{ DefaultTTL: a.config.Cache.DefaultTTL, CleanupInterval: 5 * time.Minute, }) @@ -153,13 +186,6 @@ func (a *App) initializeComponents() error { a.rescanWorker = scanner.NewRescanWorker(a.scanManager, a.metadata, a.storage, a.config.Security.RescanInterval) } - // Initialize analytics engine - log.Info().Msg("Initializing analytics engine") - a.analyticsEngine = analytics.NewEngine(analytics.Config{ - MaxEvents: 10000, - FlushInterval: 5 * time.Minute, - }) - // Initialize WebSocket server log.Info().Msg("Initializing WebSocket server") a.wsServer = websocket.NewServer(websocket.Config{ @@ -248,9 +274,28 @@ func (a *App) setupServer() error { a.app.Get("/api/stats/timeseries", a.handleTimeSeriesStats) a.app.Get("/api/info", a.handleInfo) + // Analytics endpoints + a.app.Get("/api/analytics/top", a.handleAnalyticsTopPackages) + a.app.Get("/api/analytics/trending", a.handleAnalyticsTrendingPackages) + a.app.Get("/api/analytics/trends", a.handleAnalyticsTrends) + a.app.Get("/api/analytics/total", a.handleAnalyticsTotalStats) + a.app.Get("/api/analytics/registry/:registry", a.handleAnalyticsRegistryStats) + a.app.Get("/api/analytics/package/:registry/:name", a.handleAnalyticsPackageStats) + a.app.Get("/api/analytics/search", a.handleAnalyticsSearch) + // Admin endpoints (bypass management) a.app.All("/api/admin/bypasses/:id?", a.requireAdmin, a.handleAdminBypasses) + // Admin endpoints (pre-warming) + a.app.Get("/api/admin/prewarming/status", a.requireAdmin, a.handlePrewarmingStatus) + a.app.Post("/api/admin/prewarming/trigger", a.requireAdmin, a.handlePrewarmingTrigger) + a.app.Post("/api/admin/prewarming/package", a.requireAdmin, a.handlePrewarmingPackage) + + // Admin endpoints (API key management) + a.app.Post("/api/admin/keys", a.requireAdmin, a.handleGenerateAPIKey) + a.app.Get("/api/admin/keys", a.requireAdmin, a.handleListAPIKeys) + a.app.Delete("/api/admin/keys/:key_id", a.requireAdmin, a.handleRevokeAPIKey) + // Proxy handlers (adapted from net/http) // Load git credentials if configured var credStore *vcs.CredentialStore @@ -270,22 +315,28 @@ func (a *App) setupServer() error { } } + // Go proxy with CDN caching goProxyHandler := goproxy.New(a.cache, a.networkClient, goproxy.Config{ Upstream: "https://proxy.golang.org", SumDBURL: "https://sum.golang.org", CredStore: credStore, }) - a.app.All("/go/*", adaptor.HTTPHandler(http.StripPrefix("/go", goProxyHandler))) + goProxyWithCDN := a.cdnMiddleware.Handler(http.StripPrefix("/go", goProxyHandler)) + a.app.All("/go/*", adaptor.HTTPHandler(goProxyWithCDN)) + // NPM proxy with CDN caching npmProxyHandler := npm.New(a.cache, a.networkClient, npm.Config{ Upstream: "https://registry.npmjs.org", }) - a.app.All("/npm/*", adaptor.HTTPHandler(http.StripPrefix("/npm", npmProxyHandler))) + npmProxyWithCDN := a.cdnMiddleware.Handler(http.StripPrefix("/npm", npmProxyHandler)) + a.app.All("/npm/*", adaptor.HTTPHandler(npmProxyWithCDN)) + // PyPI proxy with CDN caching pypiProxyHandler := pypi.New(a.cache, a.networkClient, pypi.Config{ Upstream: "https://pypi.org/simple", }) - a.app.All("/pypi/*", adaptor.HTTPHandler(http.StripPrefix("/pypi", pypiProxyHandler))) + pypiProxyWithCDN := a.cdnMiddleware.Handler(http.StripPrefix("/pypi", pypiProxyHandler)) + a.app.All("/pypi/*", adaptor.HTTPHandler(pypiProxyWithCDN)) // Serve frontend static files frontendDir := "frontend/dist" @@ -397,13 +448,6 @@ func (a *App) Shutdown() error { log.Error().Err(err).Msg("Error closing metadata store") } - // Close lock manager if initialized - if a.lockManager != nil { - if err := a.lockManager.Close(); err != nil { - log.Error().Err(err).Msg("Error closing lock manager") - } - } - log.Info().Msg("Shutdown complete") return nil } diff --git a/pkg/app/handlers.go b/pkg/app/handlers.go index ac9d55c..3580a04 100644 --- a/pkg/app/handlers.go +++ b/pkg/app/handlers.go @@ -500,11 +500,10 @@ func (a *App) handleInfo(c *fiber.Ctx) error { "max_cache_size": a.config.Cache.MaxSizeBytes, }, "features": map[string]bool{ - "distributed_locking": a.lockManager != nil, - "security_scanning": a.config.Security.Enabled, - "pre_warming": a.prewarmWorker != nil, - "websockets": true, - "analytics": true, + "security_scanning": a.config.Security.Enabled, + "pre_warming": a.prewarmWorker != nil, + "websockets": true, + "analytics": true, }, } diff --git a/pkg/app/handlers_admin.go b/pkg/app/handlers_admin.go index 8cd6b14..f4e4c7c 100644 --- a/pkg/app/handlers_admin.go +++ b/pkg/app/handlers_admin.go @@ -110,13 +110,13 @@ func (a *App) handleListBypasses(c *fiber.Ctx) error { // CreateBypassRequest represents the request body for creating a bypass type CreateBypassRequest struct { - Type metadata.BypassType `json:"type"` // "cve" or "package" - Target string `json:"target"` // CVE ID or package name - Reason string `json:"reason"` // Why this bypass is needed - CreatedBy string `json:"created_by"` // Admin username - ExpiresInHours int `json:"expires_in_hours"` // How many hours until expiration - AppliesTo string `json:"applies_to,omitempty"` // Optional: limit CVE bypass to specific package - NotifyOnExpiry bool `json:"notify_on_expiry"` // Send notification when expired + Type metadata.BypassType `json:"type"` + Target string `json:"target"` + Reason string `json:"reason"` + CreatedBy string `json:"created_by"` + AppliesTo string `json:"applies_to,omitempty"` + ExpiresInHours int `json:"expires_in_hours"` + NotifyOnExpiry bool `json:"notify_on_expiry"` } // handleCreateBypass creates a new CVE bypass diff --git a/pkg/app/handlers_analytics.go b/pkg/app/handlers_analytics.go new file mode 100644 index 0000000..dbfbf02 --- /dev/null +++ b/pkg/app/handlers_analytics.go @@ -0,0 +1,149 @@ +package app + +import ( + "strconv" + + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +// handleAnalyticsTopPackages returns the most downloaded packages +func (a *App) handleAnalyticsTopPackages(c *fiber.Ctx) error { + c.Set("Content-Type", "application/json") + c.Set("Access-Control-Allow-Origin", "*") + + // Get limit from query params (default: 10) + limit := 10 + if limitStr := c.Query("limit"); limitStr != "" { + if parsedLimit, err := strconv.Atoi(limitStr); err == nil && parsedLimit > 0 { + limit = parsedLimit + } + } + + packages := a.analyticsEngine.GetTopPackages(limit) + + return c.Status(fiber.StatusOK).JSON(fiber.Map{ + "packages": packages, + "total": len(packages), + }) +} + +// handleAnalyticsTrendingPackages returns trending packages +func (a *App) handleAnalyticsTrendingPackages(c *fiber.Ctx) error { + c.Set("Content-Type", "application/json") + c.Set("Access-Control-Allow-Origin", "*") + + // Get limit from query params (default: 10) + limit := 10 + if limitStr := c.Query("limit"); limitStr != "" { + if parsedLimit, err := strconv.Atoi(limitStr); err == nil && parsedLimit > 0 { + limit = parsedLimit + } + } + + packages := a.analyticsEngine.GetTrendingPackages(limit) + + return c.Status(fiber.StatusOK).JSON(fiber.Map{ + "packages": packages, + "total": len(packages), + }) +} + +// handleAnalyticsTrends returns download trends over time +func (a *App) handleAnalyticsTrends(c *fiber.Ctx) error { + c.Set("Content-Type", "application/json") + c.Set("Access-Control-Allow-Origin", "*") + + trends := a.analyticsEngine.GetTrends() + + return c.Status(fiber.StatusOK).JSON(fiber.Map{ + "trends": trends, + }) +} + +// handleAnalyticsTotalStats returns overall statistics +func (a *App) handleAnalyticsTotalStats(c *fiber.Ctx) error { + c.Set("Content-Type", "application/json") + c.Set("Access-Control-Allow-Origin", "*") + + stats := a.analyticsEngine.GetTotalStats() + + return c.Status(fiber.StatusOK).JSON(stats) +} + +// handleAnalyticsRegistryStats returns per-registry statistics +func (a *App) handleAnalyticsRegistryStats(c *fiber.Ctx) error { + c.Set("Content-Type", "application/json") + c.Set("Access-Control-Allow-Origin", "*") + + registry := c.Params("registry") + if registry == "" { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + "error": "registry parameter is required", + }) + } + + stats := a.analyticsEngine.GetRegistryStats(registry) + + return c.Status(fiber.StatusOK).JSON(stats) +} + +// handleAnalyticsPackageStats returns statistics for a specific package +func (a *App) handleAnalyticsPackageStats(c *fiber.Ctx) error { + c.Set("Content-Type", "application/json") + c.Set("Access-Control-Allow-Origin", "*") + + registry := c.Params("registry") + name := c.Params("name") + + if registry == "" || name == "" { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + "error": "registry and name parameters are required", + }) + } + + stats, exists := a.analyticsEngine.GetPackageStats(registry, name) + if !exists { + return c.Status(fiber.StatusNotFound).JSON(fiber.Map{ + "error": "package not found in analytics", + }) + } + + return c.Status(fiber.StatusOK).JSON(fiber.Map{ + "package": stats, + }) +} + +// handleAnalyticsSearch searches for packages matching a query +func (a *App) handleAnalyticsSearch(c *fiber.Ctx) error { + c.Set("Content-Type", "application/json") + c.Set("Access-Control-Allow-Origin", "*") + + query := c.Query("q") + if query == "" { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + "error": "query parameter 'q' is required", + }) + } + + // Get limit from query params (default: 20) + limit := 20 + if limitStr := c.Query("limit"); limitStr != "" { + if parsedLimit, err := strconv.Atoi(limitStr); err == nil && parsedLimit > 0 { + limit = parsedLimit + } + } + + results := a.analyticsEngine.SearchPackages(query, limit) + + log.Debug(). + Str("query", query). + Int("results", len(results)). + Msg("Analytics search completed") + + return c.Status(fiber.StatusOK).JSON(fiber.Map{ + "results": results, + "total": len(results), + "query": query, + }) +} diff --git a/pkg/app/handlers_analytics_test.go b/pkg/app/handlers_analytics_test.go new file mode 100644 index 0000000..4bb0f4a --- /dev/null +++ b/pkg/app/handlers_analytics_test.go @@ -0,0 +1,298 @@ +package app + +import ( + "encoding/json" + "net/http/httptest" + "testing" + "time" + + "github.com/gofiber/fiber/v2" + "github.com/lukaszraczylo/gohoarder/pkg/analytics" + "github.com/stretchr/testify/suite" +) + +type AnalyticsHandlersTestSuite struct { + suite.Suite + app *fiber.App + appInst *App + engine *analytics.Engine +} + +func (s *AnalyticsHandlersTestSuite) SetupTest() { + // Create analytics engine + s.engine = analytics.NewEngine(analytics.Config{ + MaxEvents: 10000, + FlushInterval: 5 * time.Minute, + }) + + // Seed some test data + s.engine.TrackDownload(analytics.PackageDownload{ + Registry: "npm", + Name: "lodash", + Version: "4.17.21", + Timestamp: time.Now(), + BytesSize: 1024, + }) + s.engine.TrackDownload(analytics.PackageDownload{ + Registry: "npm", + Name: "react", + Version: "18.0.0", + Timestamp: time.Now(), + BytesSize: 2048, + }) + s.engine.TrackDownload(analytics.PackageDownload{ + Registry: "pypi", + Name: "requests", + Version: "2.28.0", + Timestamp: time.Now(), + BytesSize: 512, + }) + + // Create app instance + s.appInst = &App{ + analyticsEngine: s.engine, + } + + // Create Fiber app + s.app = fiber.New() + + // Register routes + s.app.Get("/api/analytics/top", s.appInst.handleAnalyticsTopPackages) + s.app.Get("/api/analytics/trending", s.appInst.handleAnalyticsTrendingPackages) + s.app.Get("/api/analytics/trends", s.appInst.handleAnalyticsTrends) + s.app.Get("/api/analytics/total", s.appInst.handleAnalyticsTotalStats) + s.app.Get("/api/analytics/registry/:registry", s.appInst.handleAnalyticsRegistryStats) + s.app.Get("/api/analytics/package/:registry/:name", s.appInst.handleAnalyticsPackageStats) + s.app.Get("/api/analytics/search", s.appInst.handleAnalyticsSearch) +} + +func (s *AnalyticsHandlersTestSuite) TearDownTest() { + if s.engine != nil { + s.engine.Close() + } +} + +func TestAnalyticsHandlersTestSuite(t *testing.T) { + suite.Run(t, new(AnalyticsHandlersTestSuite)) +} + +func (s *AnalyticsHandlersTestSuite) TestHandleAnalyticsTopPackages() { + tests := []struct { + name string + queryParams string + expectedStatus int + expectError bool + }{ + { + name: "get top packages default", + queryParams: "", + expectedStatus: 200, + expectError: false, + }, + { + name: "get top packages with limit", + queryParams: "?limit=5", + expectedStatus: 200, + expectError: false, + }, + { + name: "get top packages with registry filter", + queryParams: "?registry=npm", + expectedStatus: 200, + expectError: false, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + req := httptest.NewRequest("GET", "/api/analytics/top"+tt.queryParams, nil) + resp, err := s.app.Test(req) + s.Require().NoError(err) + s.Equal(tt.expectedStatus, resp.StatusCode) + + if !tt.expectError { + var result struct { + Packages []analytics.PackageStats `json:"packages"` + Total int `json:"total"` + } + err = json.NewDecoder(resp.Body).Decode(&result) + s.NoError(err) + } + }) + } +} + +func (s *AnalyticsHandlersTestSuite) TestHandleAnalyticsTrendingPackages() { + req := httptest.NewRequest("GET", "/api/analytics/trending", nil) + resp, err := s.app.Test(req) + s.Require().NoError(err) + s.Equal(200, resp.StatusCode) + + var result struct { + Packages []analytics.PackageStats `json:"packages"` + Total int `json:"total"` + } + err = json.NewDecoder(resp.Body).Decode(&result) + s.NoError(err) +} + +func (s *AnalyticsHandlersTestSuite) TestHandleAnalyticsTrends() { + tests := []struct { + name string + queryParams string + expectedStatus int + }{ + { + name: "get trends default timeframe", + queryParams: "", + expectedStatus: 200, + }, + { + name: "get trends with registry filter", + queryParams: "?registry=npm", + expectedStatus: 200, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + req := httptest.NewRequest("GET", "/api/analytics/trends"+tt.queryParams, nil) + resp, err := s.app.Test(req) + s.Require().NoError(err) + s.Equal(tt.expectedStatus, resp.StatusCode) + }) + } +} + +func (s *AnalyticsHandlersTestSuite) TestHandleAnalyticsTotalStats() { + req := httptest.NewRequest("GET", "/api/analytics/total", nil) + resp, err := s.app.Test(req) + s.Require().NoError(err) + s.Equal(200, resp.StatusCode) + + var result struct { + TotalDownloads int64 `json:"total_downloads"` + TotalBytes int64 `json:"total_bytes"` + UniquePackages int `json:"unique_packages"` + } + err = json.NewDecoder(resp.Body).Decode(&result) + s.NoError(err) + s.Greater(result.TotalDownloads, int64(0)) +} + +func (s *AnalyticsHandlersTestSuite) TestHandleAnalyticsRegistryStats() { + tests := []struct { + name string + registry string + expectedStatus int + }{ + { + name: "npm registry stats", + registry: "npm", + expectedStatus: 200, + }, + { + name: "pypi registry stats", + registry: "pypi", + expectedStatus: 200, + }, + { + name: "go registry stats", + registry: "go", + expectedStatus: 200, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + req := httptest.NewRequest("GET", "/api/analytics/registry/"+tt.registry, nil) + resp, err := s.app.Test(req) + s.Require().NoError(err) + s.Equal(tt.expectedStatus, resp.StatusCode) + }) + } +} + +func (s *AnalyticsHandlersTestSuite) TestHandleAnalyticsPackageStats() { + tests := []struct { + name string + registry string + packageName string + expectedStatus int + }{ + { + name: "lodash package stats", + registry: "npm", + packageName: "lodash", + expectedStatus: 200, + }, + { + name: "react package stats", + registry: "npm", + packageName: "react", + expectedStatus: 200, + }, + { + name: "requests package stats", + registry: "pypi", + packageName: "requests", + expectedStatus: 200, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + req := httptest.NewRequest("GET", "/api/analytics/package/"+tt.registry+"/"+tt.packageName, nil) + resp, err := s.app.Test(req) + s.Require().NoError(err) + s.Equal(tt.expectedStatus, resp.StatusCode) + }) + } +} + +func (s *AnalyticsHandlersTestSuite) TestHandleAnalyticsSearch() { + tests := []struct { + name string + queryParams string + expectedStatus int + expectError bool + }{ + { + name: "search for lodash", + queryParams: "?q=lodash", + expectedStatus: 200, + expectError: false, + }, + { + name: "search for react", + queryParams: "?q=react", + expectedStatus: 200, + expectError: false, + }, + { + name: "search with no query", + queryParams: "", + expectedStatus: 400, // Query parameter is required + expectError: true, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + req := httptest.NewRequest("GET", "/api/analytics/search"+tt.queryParams, nil) + resp, err := s.app.Test(req) + s.Require().NoError(err) + s.Equal(tt.expectedStatus, resp.StatusCode) + + if !tt.expectError { + var result struct { + Results []analytics.PackageStats `json:"results"` + Total int `json:"total"` + Query string `json:"query"` + } + err = json.NewDecoder(resp.Body).Decode(&result) + s.NoError(err) + } + }) + } +} diff --git a/pkg/app/handlers_auth.go b/pkg/app/handlers_auth.go new file mode 100644 index 0000000..38eb61f --- /dev/null +++ b/pkg/app/handlers_auth.go @@ -0,0 +1,135 @@ +package app + +import ( + "time" + + "github.com/gofiber/fiber/v2" + "github.com/lukaszraczylo/gohoarder/pkg/auth" + "github.com/rs/zerolog/log" +) + +// GenerateAPIKeyRequest represents a request to generate a new API key +type GenerateAPIKeyRequest struct { + ExpiresInMin *int `json:"expires_in_min"` + Name string `json:"name"` + Role string `json:"role"` +} + +// handleGenerateAPIKey generates a new API key +func (a *App) handleGenerateAPIKey(c *fiber.Ctx) error { + c.Set("Content-Type", "application/json") + + var req GenerateAPIKeyRequest + if err := c.BodyParser(&req); err != nil { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + "error": "invalid JSON in request body", + }) + } + + // Validate request + if req.Name == "" { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + "error": "name is required", + }) + } + + // Parse role (default to readonly if not specified) + var role auth.Role + switch req.Role { + case "admin": + role = auth.RoleAdmin + case "readwrite": + role = auth.RoleReadWrite + case "readonly", "": + role = auth.RoleReadOnly + default: + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + "error": "invalid role, must be 'admin', 'readwrite', or 'readonly'", + }) + } + + // Calculate expiration + var expiresIn *time.Duration + if req.ExpiresInMin != nil { + duration := time.Duration(*req.ExpiresInMin) * time.Minute + expiresIn = &duration + } + + // Generate key + apiKey, rawKey, err := a.authManager.GenerateAPIKey(req.Name, role, expiresIn) + if err != nil { + log.Error().Err(err).Str("name", req.Name).Msg("Failed to generate API key") + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + "error": "failed to generate API key", + }) + } + + log.Info(). + Str("key_id", apiKey.ID). + Str("name", apiKey.Name). + Str("role", string(apiKey.Role)). + Msg("API key generated") + + // Return the key info and raw key (only time it's shown!) + return c.Status(fiber.StatusCreated).JSON(fiber.Map{ + "key": rawKey, // IMPORTANT: This is the only time the raw key is shown + "key_id": apiKey.ID, + "name": apiKey.Name, + "role": apiKey.Role, + "expires": apiKey.ExpiresAt, + "message": "Save this key now! It will not be shown again.", + }) +} + +// handleListAPIKeys lists all API keys +func (a *App) handleListAPIKeys(c *fiber.Ctx) error { + c.Set("Content-Type", "application/json") + + keys := a.authManager.ListAPIKeys() + + // Convert to response format (excluding hashed keys) + response := make([]fiber.Map, len(keys)) + for i, key := range keys { + response[i] = fiber.Map{ + "id": key.ID, + "name": key.Name, + "role": key.Role, + "created_at": key.CreatedAt, + "expires_at": key.ExpiresAt, + "last_used_at": key.LastUsedAt, + "permissions": key.Permissions, + } + } + + return c.Status(fiber.StatusOK).JSON(fiber.Map{ + "keys": response, + "total": len(response), + }) +} + +// handleRevokeAPIKey revokes an API key +func (a *App) handleRevokeAPIKey(c *fiber.Ctx) error { + c.Set("Content-Type", "application/json") + + keyID := c.Params("key_id") + if keyID == "" { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + "error": "key_id parameter is required", + }) + } + + err := a.authManager.RevokeAPIKey(keyID) + if err != nil { + log.Warn().Err(err).Str("key_id", keyID).Msg("Failed to revoke API key") + return c.Status(fiber.StatusNotFound).JSON(fiber.Map{ + "error": "API key not found", + }) + } + + log.Info().Str("key_id", keyID).Msg("API key revoked") + + return c.Status(fiber.StatusOK).JSON(fiber.Map{ + "message": "API key revoked successfully", + "key_id": keyID, + }) +} diff --git a/pkg/app/handlers_auth_test.go b/pkg/app/handlers_auth_test.go new file mode 100644 index 0000000..18686a0 --- /dev/null +++ b/pkg/app/handlers_auth_test.go @@ -0,0 +1,285 @@ +package app + +import ( + "bytes" + "encoding/json" + "net/http/httptest" + "testing" + + "github.com/gofiber/fiber/v2" + "github.com/lukaszraczylo/gohoarder/pkg/auth" + "github.com/stretchr/testify/suite" +) + +type AuthHandlersTestSuite struct { + suite.Suite + app *fiber.App + appInst *App + authManager *auth.Manager +} + +func (s *AuthHandlersTestSuite) SetupTest() { + // Create auth manager + s.authManager = auth.New() + + // Create app instance + s.appInst = &App{ + authManager: s.authManager, + } + + // Create Fiber app + s.app = fiber.New() + + // Register routes + s.app.Post("/api/admin/keys", s.appInst.handleGenerateAPIKey) + s.app.Get("/api/admin/keys", s.appInst.handleListAPIKeys) + s.app.Delete("/api/admin/keys/:key_id", s.appInst.handleRevokeAPIKey) +} + +func TestAuthHandlersTestSuite(t *testing.T) { + suite.Run(t, new(AuthHandlersTestSuite)) +} + +func (s *AuthHandlersTestSuite) TestHandleGenerateAPIKey() { + tests := []struct { + requestBody map[string]string + name string + expectedStatus int + expectedRole string + expectKey bool + }{ + { + name: "generate read-only key", + requestBody: map[string]string{ + "role": "readonly", + "name": "test-readonly-key", + }, + expectedStatus: 201, + expectedRole: "readonly", + expectKey: true, + }, + { + name: "generate read-write key", + requestBody: map[string]string{ + "role": "readwrite", + "name": "test-readwrite-key", + }, + expectedStatus: 201, + expectedRole: "readwrite", + expectKey: true, + }, + { + name: "generate admin key", + requestBody: map[string]string{ + "role": "admin", + "name": "test-admin-key", + }, + expectedStatus: 201, + expectedRole: "admin", + expectKey: true, + }, + { + name: "invalid role", + requestBody: map[string]string{ + "role": "invalid-role", + "name": "test-key", + }, + expectedStatus: 400, + expectKey: false, + }, + { + name: "missing role defaults to readonly", + requestBody: map[string]string{ + "name": "test-key-default-role", + }, + expectedStatus: 201, + expectedRole: "readonly", // Role defaults to readonly when not specified + expectKey: true, + }, + { + name: "missing name", + requestBody: map[string]string{ + "role": "read-only", + }, + expectedStatus: 400, + expectKey: false, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + bodyBytes, err := json.Marshal(tt.requestBody) + s.Require().NoError(err) + + req := httptest.NewRequest("POST", "/api/admin/keys", bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + + resp, err := s.app.Test(req) + s.Require().NoError(err) + s.Equal(tt.expectedStatus, resp.StatusCode) + + if tt.expectKey { + var result struct { + Key string `json:"key"` + KeyID string `json:"key_id"` + Role string `json:"role"` + Name string `json:"name"` + Message string `json:"message"` + } + err = json.NewDecoder(resp.Body).Decode(&result) + s.NoError(err) + s.NotEmpty(result.Key) + s.NotEmpty(result.KeyID) + s.Equal(tt.expectedRole, result.Role) + s.Equal(tt.requestBody["name"], result.Name) + } + }) + } +} + +func (s *AuthHandlersTestSuite) TestHandleListAPIKeys() { + // Generate some test keys first + s.authManager.GenerateAPIKey("test-key-1", auth.RoleReadOnly, nil) + s.authManager.GenerateAPIKey("test-key-2", auth.RoleReadWrite, nil) + s.authManager.GenerateAPIKey("test-key-3", auth.RoleAdmin, nil) + + req := httptest.NewRequest("GET", "/api/admin/keys", nil) + resp, err := s.app.Test(req) + s.Require().NoError(err) + s.Equal(200, resp.StatusCode) + + var result struct { + Keys []map[string]interface{} `json:"keys"` + Total int `json:"total"` + } + err = json.NewDecoder(resp.Body).Decode(&result) + s.NoError(err) + s.GreaterOrEqual(result.Total, 3) + + // Verify keys don't include the actual key value + for _, key := range result.Keys { + s.NotEmpty(key["id"]) + s.NotEmpty(key["role"]) + s.NotEmpty(key["name"]) + s.NotEmpty(key["created_at"]) + } +} + +func (s *AuthHandlersTestSuite) TestHandleRevokeAPIKey() { + // Generate a test key + keyInfo, _, _ := s.authManager.GenerateAPIKey("test-revoke-key", auth.RoleReadOnly, nil) + + tests := []struct { + name string + keyID string + expectedStatus int + }{ + { + name: "revoke existing key", + keyID: keyInfo.ID, + expectedStatus: 200, + }, + { + name: "revoke non-existent key", + keyID: "non-existent-key-id", + expectedStatus: 404, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + req := httptest.NewRequest("DELETE", "/api/admin/keys/"+tt.keyID, nil) + resp, err := s.app.Test(req) + s.Require().NoError(err) + s.Equal(tt.expectedStatus, resp.StatusCode) + + if tt.expectedStatus == 200 { + var result map[string]string + err = json.NewDecoder(resp.Body).Decode(&result) + s.NoError(err) + s.Contains(result, "message") + } + }) + } +} + +func (s *AuthHandlersTestSuite) TestHandleGenerateAPIKeyInvalidJSON() { + req := httptest.NewRequest("POST", "/api/admin/keys", bytes.NewReader([]byte("invalid json"))) + req.Header.Set("Content-Type", "application/json") + + resp, err := s.app.Test(req) + s.Require().NoError(err) + s.Equal(400, resp.StatusCode) +} + +func (s *AuthHandlersTestSuite) TestGenerateAndRevokeKeyFlow() { + // Generate a key + bodyBytes, _ := json.Marshal(map[string]string{ + "role": "readonly", + "name": "integration-test-key", + }) + + req1 := httptest.NewRequest("POST", "/api/admin/keys", bytes.NewReader(bodyBytes)) + req1.Header.Set("Content-Type", "application/json") + resp1, err := s.app.Test(req1) + s.Require().NoError(err) + s.Equal(201, resp1.StatusCode) + + var createResult struct { + Key string `json:"key"` + KeyID string `json:"key_id"` + } + err = json.NewDecoder(resp1.Body).Decode(&createResult) + s.Require().NoError(err) + keyID := createResult.KeyID + + // List keys - should include our new key + req2 := httptest.NewRequest("GET", "/api/admin/keys", nil) + resp2, err := s.app.Test(req2) + s.Require().NoError(err) + s.Equal(200, resp2.StatusCode) + + var listResult struct { + Keys []map[string]interface{} `json:"keys"` + Total int `json:"total"` + } + err = json.NewDecoder(resp2.Body).Decode(&listResult) + s.Require().NoError(err) + + found := false + for _, key := range listResult.Keys { + if key["id"].(string) == keyID { + found = true + break + } + } + s.True(found, "newly created key should be in the list") + + // Revoke the key + req3 := httptest.NewRequest("DELETE", "/api/admin/keys/"+keyID, nil) + resp3, err := s.app.Test(req3) + s.Require().NoError(err) + s.Equal(200, resp3.StatusCode) + + // List keys again - should not include the revoked key + req4 := httptest.NewRequest("GET", "/api/admin/keys", nil) + resp4, err := s.app.Test(req4) + s.Require().NoError(err) + s.Equal(200, resp4.StatusCode) + + var listResult2 struct { + Keys []map[string]interface{} `json:"keys"` + Total int `json:"total"` + } + err = json.NewDecoder(resp4.Body).Decode(&listResult2) + s.Require().NoError(err) + + found = false + for _, key := range listResult2.Keys { + if key["id"].(string) == keyID { + found = true + break + } + } + s.False(found, "revoked key should not be in the list") +} diff --git a/pkg/app/handlers_prewarming.go b/pkg/app/handlers_prewarming.go new file mode 100644 index 0000000..c51638f --- /dev/null +++ b/pkg/app/handlers_prewarming.go @@ -0,0 +1,94 @@ +package app + +import ( + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +// handlePrewarmingStatus returns the status of the pre-warming worker +func (a *App) handlePrewarmingStatus(c *fiber.Ctx) error { + c.Set("Content-Type", "application/json") + + status := a.prewarmWorker.GetStatus() + + return c.Status(fiber.StatusOK).JSON(status) +} + +// handlePrewarmingTrigger manually triggers a pre-warming cycle +func (a *App) handlePrewarmingTrigger(c *fiber.Ctx) error { + c.Set("Content-Type", "application/json") + + ctx := c.Context() + a.prewarmWorker.TriggerPrewarm(ctx) + + log.Info().Msg("Pre-warming manually triggered via API") + + return c.Status(fiber.StatusOK).JSON(fiber.Map{ + "message": "Pre-warming cycle triggered successfully", + }) +} + +// PrewarmPackageRequest represents a request to pre-warm a specific package +type PrewarmPackageRequest struct { + Registry string `json:"registry"` + Name string `json:"name"` + Version string `json:"version"` +} + +// handlePrewarmingPackage pre-warms a specific package +func (a *App) handlePrewarmingPackage(c *fiber.Ctx) error { + c.Set("Content-Type", "application/json") + + var req PrewarmPackageRequest + if err := c.BodyParser(&req); err != nil { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + "error": "invalid JSON in request body", + }) + } + + // Validate request + if req.Registry == "" { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + "error": "registry is required", + }) + } + if req.Name == "" { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + "error": "name is required", + }) + } + if req.Version == "" { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + "error": "version is required", + }) + } + + ctx := c.Context() + err := a.prewarmWorker.PrewarmPackage(ctx, req.Registry, req.Name, req.Version) + if err != nil { + log.Error(). + Err(err). + Str("registry", req.Registry). + Str("name", req.Name). + Str("version", req.Version). + Msg("Failed to pre-warm package") + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + "error": "failed to pre-warm package", + }) + } + + log.Info(). + Str("registry", req.Registry). + Str("name", req.Name). + Str("version", req.Version). + Msg("Package pre-warmed via API") + + return c.Status(fiber.StatusOK).JSON(fiber.Map{ + "message": "Package pre-warmed successfully", + "package": fiber.Map{ + "registry": req.Registry, + "name": req.Name, + "version": req.Version, + }, + }) +} diff --git a/pkg/app/handlers_prewarming_test.go b/pkg/app/handlers_prewarming_test.go new file mode 100644 index 0000000..941b2bc --- /dev/null +++ b/pkg/app/handlers_prewarming_test.go @@ -0,0 +1,162 @@ +package app + +import ( + "bytes" + "encoding/json" + "net/http/httptest" + "testing" + + "github.com/gofiber/fiber/v2" + "github.com/lukaszraczylo/gohoarder/pkg/prewarming" + "github.com/stretchr/testify/suite" +) + +type PrewarmingHandlersTestSuite struct { + suite.Suite + app *fiber.App + appInst *App + prewarmWorker *prewarming.Worker +} + +func (s *PrewarmingHandlersTestSuite) SetupTest() { + // Create pre-warming worker (disabled by default) + s.prewarmWorker = prewarming.NewWorker(prewarming.Config{ + Enabled: false, + MaxConcurrent: 5, + }) + + // Create app instance + s.appInst = &App{ + prewarmWorker: s.prewarmWorker, + } + + // Create Fiber app + s.app = fiber.New() + + // Register routes + s.app.Get("/api/admin/prewarming/status", s.appInst.handlePrewarmingStatus) + s.app.Post("/api/admin/prewarming/trigger", s.appInst.handlePrewarmingTrigger) + s.app.Post("/api/admin/prewarming/package", s.appInst.handlePrewarmingPackage) +} + +func (s *PrewarmingHandlersTestSuite) TearDownTest() { + if s.prewarmWorker != nil { + s.prewarmWorker.Stop() + } +} + +func TestPrewarmingHandlersTestSuite(t *testing.T) { + suite.Run(t, new(PrewarmingHandlersTestSuite)) +} + +func (s *PrewarmingHandlersTestSuite) TestHandlePrewarmingStatus() { + req := httptest.NewRequest("GET", "/api/admin/prewarming/status", nil) + resp, err := s.app.Test(req) + s.Require().NoError(err) + s.Equal(200, resp.StatusCode) + + var result struct { + Enabled bool `json:"enabled"` + Running bool `json:"running"` + QueueSize int `json:"queue_size"` + ActiveWorkers int `json:"active_workers"` + } + err = json.NewDecoder(resp.Body).Decode(&result) + s.NoError(err) + s.False(result.Enabled) // Disabled in test setup +} + +func (s *PrewarmingHandlersTestSuite) TestHandlePrewarmingTrigger() { + req := httptest.NewRequest("POST", "/api/admin/prewarming/trigger", nil) + resp, err := s.app.Test(req) + s.Require().NoError(err) + s.Equal(200, resp.StatusCode) + + var result map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&result) + s.NoError(err) + s.Contains(result, "message") +} + +func (s *PrewarmingHandlersTestSuite) TestHandlePrewarmingPackage() { + tests := []struct { + requestBody map[string]string + name string + expectedStatus int + }{ + { + name: "prewarm npm package", + requestBody: map[string]string{ + "registry": "npm", + "name": "lodash", + "version": "4.17.21", + }, + expectedStatus: 200, + }, + { + name: "prewarm pypi package", + requestBody: map[string]string{ + "registry": "pypi", + "name": "requests", + "version": "2.28.0", + }, + expectedStatus: 200, + }, + { + name: "prewarm go package", + requestBody: map[string]string{ + "registry": "go", + "name": "github.com/stretchr/testify", + "version": "v1.8.0", + }, + expectedStatus: 200, + }, + { + name: "missing registry", + requestBody: map[string]string{ + "name": "lodash", + "version": "4.17.21", + }, + expectedStatus: 400, + }, + { + name: "missing name", + requestBody: map[string]string{ + "registry": "npm", + "version": "4.17.21", + }, + expectedStatus: 400, + }, + { + name: "missing version", + requestBody: map[string]string{ + "registry": "npm", + "name": "lodash", + }, + expectedStatus: 400, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + bodyBytes, err := json.Marshal(tt.requestBody) + s.Require().NoError(err) + + req := httptest.NewRequest("POST", "/api/admin/prewarming/package", bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + + resp, err := s.app.Test(req) + s.Require().NoError(err) + s.Equal(tt.expectedStatus, resp.StatusCode) + }) + } +} + +func (s *PrewarmingHandlersTestSuite) TestHandlePrewarmingPackageInvalidJSON() { + req := httptest.NewRequest("POST", "/api/admin/prewarming/package", bytes.NewReader([]byte("invalid json"))) + req.Header.Set("Content-Type", "application/json") + + resp, err := s.app.Test(req) + s.Require().NoError(err) + s.Equal(400, resp.StatusCode) +} diff --git a/pkg/auth/validation_cache.go b/pkg/auth/validation_cache.go index 4b9e316..c9ead4c 100644 --- a/pkg/auth/validation_cache.go +++ b/pkg/auth/validation_cache.go @@ -7,9 +7,9 @@ import ( // ValidationResult represents a cached credential validation result type ValidationResult struct { - Allowed bool ExpiresAt time.Time Reason string + Allowed bool } // ValidationCache caches credential validation results to reduce upstream checks diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index c350813..f0873d6 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -11,6 +11,7 @@ import ( "sync" "time" + "github.com/lukaszraczylo/gohoarder/pkg/analytics" "github.com/lukaszraczylo/gohoarder/pkg/errors" "github.com/lukaszraczylo/gohoarder/pkg/metadata" "github.com/lukaszraczylo/gohoarder/pkg/metrics" @@ -27,15 +28,21 @@ type ScannerInterface interface { CheckVulnerabilities(ctx context.Context, registry, packageName, version string) (blocked bool, reason string, err error) } +// AnalyticsInterface defines the interface for analytics tracking +type AnalyticsInterface interface { + TrackDownload(download analytics.PackageDownload) +} + // Manager coordinates caching operations between storage and metadata type Manager struct { - storage storage.StorageBackend - metadata metadata.MetadataStore - scanner ScannerInterface - config Config - sf singleflight.Group - mu sync.RWMutex - evicting bool + storage storage.StorageBackend + metadata metadata.MetadataStore + scanner ScannerInterface + analytics AnalyticsInterface + sf singleflight.Group + config Config + mu sync.RWMutex + evicting bool } // Config holds cache manager configuration @@ -48,15 +55,15 @@ type Config struct { // CacheEntry represents a cached package type CacheEntry struct { - Package *metadata.Package Data io.ReadCloser - FromCache bool + Package *metadata.Package UpstreamURL string CacheControl string + FromCache bool } // New creates a new cache manager -func New(storage storage.StorageBackend, metadata metadata.MetadataStore, scanner ScannerInterface, config Config) (*Manager, error) { +func New(storage storage.StorageBackend, metadata metadata.MetadataStore, scanner ScannerInterface, analytics AnalyticsInterface, config Config) (*Manager, error) { if storage == nil { return nil, errors.New(errors.ErrCodeInvalidConfig, "storage backend is required") } @@ -70,6 +77,11 @@ func New(storage storage.StorageBackend, metadata metadata.MetadataStore, scanne log.Info().Msg("Cache manager initialized with security scanning enabled") } + // Analytics is optional - can be nil if analytics tracking is disabled + if analytics != nil { + log.Info().Msg("Cache manager initialized with analytics tracking enabled") + } + if config.DefaultTTL == 0 { config.DefaultTTL = 7 * 24 * time.Hour // 7 days default } @@ -87,10 +99,11 @@ func New(storage storage.StorageBackend, metadata metadata.MetadataStore, scanne } manager := &Manager{ - storage: storage, - metadata: metadata, - scanner: scanner, - config: config, + storage: storage, + metadata: metadata, + scanner: scanner, + analytics: analytics, + config: config, } // Start background cleanup worker @@ -134,6 +147,11 @@ func (m *Manager) getOrFetch(ctx context.Context, registry, name, version string metrics.RecordCacheHit(registry) _ = m.metadata.UpdateDownloadCount(ctx, registry, name, version) // #nosec G104 -- Async update, error logged + // Track download in analytics if enabled + if m.analytics != nil { + m.trackDownload(registry, name, version, pkg.Size) + } + // Check for vulnerabilities if scanner is enabled if m.scanner != nil { blocked, reason, err := m.scanner.CheckVulnerabilities(ctx, registry, name, version) @@ -552,6 +570,21 @@ func (m *Manager) Health(ctx context.Context) error { return nil } +// trackDownload tracks a package download event in analytics +func (m *Manager) trackDownload(registry, name, version string, size int64) { + download := analytics.PackageDownload{ + Registry: registry, + Name: name, + Version: version, + Timestamp: time.Now(), + BytesSize: size, + ClientIP: "", // TODO: Extract from context if available + UserAgent: "", // TODO: Extract from context if available + } + + m.analytics.TrackDownload(download) +} + // Close closes the cache manager func (m *Manager) Close() error { var err error diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go index 67ef6b8..7d232a0 100644 --- a/pkg/cache/cache_test.go +++ b/pkg/cache/cache_test.go @@ -197,12 +197,12 @@ func (m *MockMetadataStore) AggregateDownloadData(ctx context.Context) error { // TestNew tests cache manager creation func TestNew(t *testing.T) { tests := []struct { - name string storage storage.StorageBackend metadata metadata.MetadataStore + name string + errContains string config Config wantErr bool - errContains string }{ // GOOD: Valid configuration { @@ -262,7 +262,7 @@ func TestNew(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - manager, err := New(tt.storage, tt.metadata, nil, tt.config) + manager, err := New(tt.storage, tt.metadata, nil, nil, tt.config) if tt.wantErr { require.Error(t, err) @@ -295,15 +295,15 @@ func TestNew(t *testing.T) { // TestGet tests cache retrieval with various scenarios func TestGet(t *testing.T) { tests := []struct { + setupMock func(*MockStorageBackend, *MockMetadataStore) + fetchFunc func(context.Context) (io.ReadCloser, string, error) name string registry string packageName string version string - setupMock func(*MockStorageBackend, *MockMetadataStore) - fetchFunc func(context.Context) (io.ReadCloser, string, error) + errContains string wantFromCache bool wantErr bool - errContains string }{ // GOOD: Cache hit { @@ -489,7 +489,7 @@ func TestGet(t *testing.T) { tt.setupMock(mockStorage, mockMetadata) } - manager, err := New(mockStorage, mockMetadata, nil, Config{ + manager, err := New(mockStorage, mockMetadata, nil, nil, Config{ DefaultTTL: 24 * time.Hour, CleanupInterval: 1 * time.Hour, }) @@ -523,13 +523,13 @@ func TestGet(t *testing.T) { // TestDelete tests package deletion func TestDelete(t *testing.T) { tests := []struct { + setupMock func(*MockStorageBackend, *MockMetadataStore) name string registry string packageName string version string - setupMock func(*MockStorageBackend, *MockMetadataStore) - wantErr bool errContains string + wantErr bool }{ // GOOD: Successful deletion { @@ -615,7 +615,7 @@ func TestDelete(t *testing.T) { tt.setupMock(mockStorage, mockMetadata) } - manager, err := New(mockStorage, mockMetadata, nil, Config{}) + manager, err := New(mockStorage, mockMetadata, nil, nil, Config{}) require.NoError(t, err) ctx := context.Background() @@ -639,10 +639,10 @@ func TestDelete(t *testing.T) { // TestHealth tests health check functionality func TestHealth(t *testing.T) { tests := []struct { - name string setupMock func(*MockStorageBackend, *MockMetadataStore) - wantErr bool + name string errContains string + wantErr bool }{ // GOOD: Both healthy { @@ -692,7 +692,7 @@ func TestHealth(t *testing.T) { tt.setupMock(mockStorage, mockMetadata) } - manager, err := New(mockStorage, mockMetadata, nil, Config{}) + manager, err := New(mockStorage, mockMetadata, nil, nil, Config{}) require.NoError(t, err) ctx := context.Background() @@ -727,7 +727,7 @@ func TestGetStats(t *testing.T) { mockMetadata.On("GetStats", mock.Anything, "npm").Return(expectedStats, nil) - manager, err := New(mockStorage, mockMetadata, nil, Config{}) + manager, err := New(mockStorage, mockMetadata, nil, nil, Config{}) require.NoError(t, err) ctx := context.Background() @@ -741,8 +741,8 @@ func TestGetStats(t *testing.T) { // TestClose tests manager cleanup func TestClose(t *testing.T) { tests := []struct { - name string setupMock func(*MockStorageBackend, *MockMetadataStore) + name string wantErr bool }{ // GOOD: Clean close @@ -792,7 +792,7 @@ func TestClose(t *testing.T) { tt.setupMock(mockStorage, mockMetadata) } - manager, err := New(mockStorage, mockMetadata, nil, Config{}) + manager, err := New(mockStorage, mockMetadata, nil, nil, Config{}) require.NoError(t, err) err = manager.Close() // #nosec G104 -- Cleanup, error not critical @@ -812,11 +812,11 @@ func TestClose(t *testing.T) { // TestEvict tests LRU eviction func TestEvict(t *testing.T) { tests := []struct { - name string - needed int64 setupMock func(*MockStorageBackend, *MockMetadataStore) - wantErr bool + name string errContains string + needed int64 + wantErr bool }{ // GOOD: Successful eviction { @@ -881,7 +881,7 @@ func TestEvict(t *testing.T) { tt.setupMock(mockStorage, mockMetadata) } - manager, err := New(mockStorage, mockMetadata, nil, Config{}) + manager, err := New(mockStorage, mockMetadata, nil, nil, Config{}) require.NoError(t, err) ctx := context.Background() @@ -907,7 +907,7 @@ func TestGenerateStorageKey(t *testing.T) { mockStorage := &MockStorageBackend{} mockMetadata := &MockMetadataStore{} - manager, err := New(mockStorage, mockMetadata, nil, Config{}) + manager, err := New(mockStorage, mockMetadata, nil, nil, Config{}) require.NoError(t, err) tests := []struct { @@ -954,7 +954,7 @@ func TestConcurrentGet(t *testing.T) { io.NopCloser(bytes.NewReader([]byte("test data"))), nil).Maybe() mockMetadata.On("UpdateDownloadCount", mock.Anything, "npm", "concurrent", "1.0.0").Return(nil).Maybe() - manager, err := New(mockStorage, mockMetadata, nil, Config{}) + manager, err := New(mockStorage, mockMetadata, nil, nil, Config{}) require.NoError(t, err) ctx := context.Background() diff --git a/pkg/cdn/cdn.go b/pkg/cdn/cdn.go index 645d19b..ec1c01b 100644 --- a/pkg/cdn/cdn.go +++ b/pkg/cdn/cdn.go @@ -4,10 +4,7 @@ import ( "crypto/md5" // #nosec G501 -- MD5 used for ETag generation, not cryptographic security "encoding/hex" "fmt" - "io" "net/http" - "strconv" - "time" "github.com/rs/zerolog/log" ) @@ -214,38 +211,11 @@ func (m *Middleware) generateETag(body []byte) string { return `"` + hex.EncodeToString(hash[:]) + `"` } -// SetLastModified sets the Last-Modified header -func SetLastModified(w http.ResponseWriter, t time.Time) { - w.Header().Set("Last-Modified", t.UTC().Format(http.TimeFormat)) -} - -// SetCacheControl sets a custom Cache-Control header -func SetCacheControl(w http.ResponseWriter, cc CacheControl) { - w.Header().Set("Cache-Control", cc.String()) -} - -// SetNoCache sets headers to prevent caching -func SetNoCache(w http.ResponseWriter) { - w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate") - w.Header().Set("Pragma", "no-cache") - w.Header().Set("Expires", "0") -} - -// SetImmutable sets headers for immutable content (content-addressed files) -func SetImmutable(w http.ResponseWriter, maxAge int) { - cc := CacheControl{ - Public: true, - MaxAge: maxAge, - Immutable: true, - } - w.Header().Set("Cache-Control", cc.String()) -} - // responseWriter wraps http.ResponseWriter to capture response type responseWriter struct { http.ResponseWriter - statusCode int body []byte + statusCode int } func (rw *responseWriter) WriteHeader(statusCode int) { @@ -261,100 +231,3 @@ func (rw *responseWriter) Write(b []byte) (int, error) { rw.body = append(rw.body, b...) return rw.ResponseWriter.Write(b) } - -// HandleRange handles HTTP Range requests for partial content -func HandleRange(w http.ResponseWriter, r *http.Request, content io.ReadSeeker, size int64, modTime time.Time) error { - // Set Last-Modified header - SetLastModified(w, modTime) - - // Check for Range header - rangeHeader := r.Header.Get("Range") - if rangeHeader == "" { - // No range request - serve full content - w.Header().Set("Content-Length", strconv.FormatInt(size, 10)) - w.Header().Set("Accept-Ranges", "bytes") - w.WriteHeader(http.StatusOK) - _, err := io.Copy(w, content) - return err - } - - // Parse range header (simplified - only handles single range) - // Format: bytes=start-end - var start, end int64 - n, err := fmt.Sscanf(rangeHeader, "bytes=%d-%d", &start, &end) - if err != nil || n != 2 { - // Invalid range - serve full content - w.Header().Set("Content-Length", strconv.FormatInt(size, 10)) - w.Header().Set("Accept-Ranges", "bytes") - w.WriteHeader(http.StatusOK) - _, err := io.Copy(w, content) - return err - } - - // Validate range - if start < 0 || start >= size || end < start || end >= size { - w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", size)) - w.WriteHeader(http.StatusRequestedRangeNotSatisfiable) - return nil - } - - // Seek to start position - if _, err := content.Seek(start, io.SeekStart); err != nil { - return err - } - - // Calculate content length - contentLength := end - start + 1 - - // Set headers for partial content - w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, size)) - w.Header().Set("Content-Length", strconv.FormatInt(contentLength, 10)) - w.Header().Set("Accept-Ranges", "bytes") - w.WriteHeader(http.StatusPartialContent) - - // Copy range to response - _, err = io.CopyN(w, content, contentLength) - return err -} - -// DefaultCacheControl returns sensible defaults for different content types -func DefaultCacheControl(contentType string, versioned bool) CacheControl { - if versioned { - // Content-addressed or versioned resources can be cached forever - return CacheControl{ - Public: true, - MaxAge: 31536000, // 1 year - Immutable: true, - } - } - - // Default caching based on content type - switch contentType { - case "application/json": - return CacheControl{ - Public: true, - MaxAge: 3600, // 1 hour - SMaxAge: 7200, // 2 hours for shared caches - } - case "application/octet-stream", "application/x-gzip", "application/zip": - // Binary packages - return CacheControl{ - Public: true, - MaxAge: 86400, // 1 day - SMaxAge: 604800, // 1 week for shared caches - } - case "text/html": - // HTML should revalidate - return CacheControl{ - Public: true, - MaxAge: 0, - MustRevalidate: true, - } - default: - return CacheControl{ - Public: true, - MaxAge: 3600, // 1 hour default - SMaxAge: 7200, - } - } -} diff --git a/pkg/cdn/cdn_test.go b/pkg/cdn/cdn_test.go new file mode 100644 index 0000000..b60d6c5 --- /dev/null +++ b/pkg/cdn/cdn_test.go @@ -0,0 +1,299 @@ +package cdn + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/suite" +) + +type CDNMiddlewareTestSuite struct { + suite.Suite + middleware *Middleware +} + +func (s *CDNMiddlewareTestSuite) SetupTest() { + s.middleware = NewMiddleware(Config{ + DefaultCacheControl: CacheControl{ + Public: true, + MaxAge: 3600, + SMaxAge: 7200, + }, + EnableETag: true, + EnableVary: true, + }) +} + +func TestCDNMiddlewareTestSuite(t *testing.T) { + suite.Run(t, new(CDNMiddlewareTestSuite)) +} + +func (s *CDNMiddlewareTestSuite) TestCacheControlHeader() { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("test response")) + }) + + wrappedHandler := s.middleware.Handler(handler) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + + wrappedHandler.ServeHTTP(w, req) + + s.Equal(http.StatusOK, w.Code) + s.Contains(w.Header().Get("Cache-Control"), "public") + s.Contains(w.Header().Get("Cache-Control"), "max-age=3600") + s.Contains(w.Header().Get("Cache-Control"), "s-maxage=7200") +} + +func (s *CDNMiddlewareTestSuite) TestETagGeneration() { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("test response content")) + }) + + wrappedHandler := s.middleware.Handler(handler) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + + wrappedHandler.ServeHTTP(w, req) + + s.Equal(http.StatusOK, w.Code) + etag := w.Header().Get("ETag") + s.NotEmpty(etag) + s.True(len(etag) > 0) +} + +func (s *CDNMiddlewareTestSuite) TestETagConsistencyAcrossRequests() { + responseBody := []byte("test response content") + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write(responseBody) + }) + + wrappedHandler := s.middleware.Handler(handler) + + // First request to get ETag + req1 := httptest.NewRequest("GET", "/test", nil) + w1 := httptest.NewRecorder() + wrappedHandler.ServeHTTP(w1, req1) + etag := w1.Header().Get("ETag") + s.NotEmpty(etag) + s.Equal(http.StatusOK, w1.Code) + + // Verify ETag is consistent for same content + req2 := httptest.NewRequest("GET", "/test", nil) + w2 := httptest.NewRecorder() + wrappedHandler.ServeHTTP(w2, req2) + etag2 := w2.Header().Get("ETag") + s.Equal(etag, etag2, "ETag should be consistent for same content") +} + +func (s *CDNMiddlewareTestSuite) TestVaryHeader() { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("test")) + }) + + wrappedHandler := s.middleware.Handler(handler) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Accept-Encoding", "gzip") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer token") + + w := httptest.NewRecorder() + wrappedHandler.ServeHTTP(w, req) + + vary := w.Header().Get("Vary") + s.NotEmpty(vary) + s.Contains(vary, "Accept-Encoding") + s.Contains(vary, "Authorization") + s.Contains(vary, "Accept") +} + +func (s *CDNMiddlewareTestSuite) TestCacheControlString() { + tests := []struct { + name string + expected string + cc CacheControl + }{ + { + name: "public with max-age", + cc: CacheControl{ + Public: true, + MaxAge: 3600, + }, + expected: "public, max-age=3600", + }, + { + name: "private with no-cache", + cc: CacheControl{ + Private: true, + NoCache: true, + }, + expected: "private, no-cache", + }, + { + name: "immutable", + cc: CacheControl{ + Public: true, + MaxAge: 31536000, + Immutable: true, + }, + expected: "public, immutable, max-age=31536000", + }, + { + name: "no-store", + cc: CacheControl{ + NoStore: true, + }, + expected: "no-store", + }, + { + name: "must-revalidate", + cc: CacheControl{ + Public: true, + MustRevalidate: true, + }, + expected: "public, must-revalidate", + }, + { + name: "s-maxage", + cc: CacheControl{ + Public: true, + MaxAge: 3600, + SMaxAge: 7200, + }, + expected: "public, max-age=3600, s-maxage=7200", + }, + { + name: "stale-while-revalidate", + cc: CacheControl{ + Public: true, + MaxAge: 3600, + StaleWhileRevalidate: 86400, + }, + expected: "public, max-age=3600, stale-while-revalidate=86400", + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + result := tt.cc.String() + // Check that all expected parts are in the result + for _, part := range splitCacheControl(tt.expected) { + s.Contains(result, part) + } + }) + } +} + +func (s *CDNMiddlewareTestSuite) TestGenerateETag() { + tests := []struct { + name string + body []byte + expected bool // true if ETag should be generated + }{ + { + name: "non-empty body", + body: []byte("test content"), + expected: true, + }, + { + name: "empty body", + body: []byte{}, + expected: true, // Empty body still generates ETag (MD5 of empty string) + }, + { + name: "nil body", + body: nil, + expected: false, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + etag := s.middleware.generateETag(tt.body) + if tt.expected { + s.NotEmpty(etag) + s.True(len(etag) > 2) // Should be quoted + } else { + s.Empty(etag) + } + }) + } +} + +func (s *CDNMiddlewareTestSuite) TestETagConsistency() { + // Same content should produce same ETag + body := []byte("consistent content") + etag1 := s.middleware.generateETag(body) + etag2 := s.middleware.generateETag(body) + + s.Equal(etag1, etag2) + + // Different content should produce different ETag + body2 := []byte("different content") + etag3 := s.middleware.generateETag(body2) + + s.NotEqual(etag1, etag3) +} + +func (s *CDNMiddlewareTestSuite) TestNoCacheFor4xxErrors() { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("not found")) + }) + + wrappedHandler := s.middleware.Handler(handler) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + + wrappedHandler.ServeHTTP(w, req) + + s.Equal(http.StatusNotFound, w.Code) + // 4xx errors should not have cache headers applied + // (based on the middleware only applying headers for 2xx status codes) +} + +func (s *CDNMiddlewareTestSuite) TestNoCacheFor5xxErrors() { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("error")) + }) + + wrappedHandler := s.middleware.Handler(handler) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + + wrappedHandler.ServeHTTP(w, req) + + s.Equal(http.StatusInternalServerError, w.Code) + // 5xx errors should not have cache headers applied +} + +// Helper function to split cache-control string +func splitCacheControl(s string) []string { + var parts []string + current := "" + for _, char := range s { + if char == ',' { + if current != "" { + parts = append(parts, current) + current = "" + } + } else if char != ' ' { + current += string(char) + } + } + if current != "" { + parts = append(parts, current) + } + return parts +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 357d824..581984a 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -7,42 +7,42 @@ import ( // Config is the main configuration struct type Config struct { - Server ServerConfig `mapstructure:"server" json:"server"` Storage StorageConfig `mapstructure:"storage" json:"storage"` - Metadata MetadataConfig `mapstructure:"metadata" json:"metadata"` Cache CacheConfig `mapstructure:"cache" json:"cache"` Security SecurityConfig `mapstructure:"security" json:"security"` - Auth AuthConfig `mapstructure:"auth" json:"auth"` - Network NetworkConfig `mapstructure:"network" json:"network"` - Logging LoggingConfig `mapstructure:"logging" json:"logging"` + Metadata MetadataConfig `mapstructure:"metadata" json:"metadata"` Handlers HandlersConfig `mapstructure:"handlers" json:"handlers"` + Server ServerConfig `mapstructure:"server" json:"server"` + Logging LoggingConfig `mapstructure:"logging" json:"logging"` + Network NetworkConfig `mapstructure:"network" json:"network"` + Auth AuthConfig `mapstructure:"auth" json:"auth"` } // ServerConfig contains HTTP server configuration type ServerConfig struct { + TLS TLSConfig `mapstructure:"tls" json:"tls"` Host string `mapstructure:"host" json:"host"` Port int `mapstructure:"port" json:"port"` ReadTimeout time.Duration `mapstructure:"read_timeout" json:"read_timeout"` WriteTimeout time.Duration `mapstructure:"write_timeout" json:"write_timeout"` IdleTimeout time.Duration `mapstructure:"idle_timeout" json:"idle_timeout"` - TLS TLSConfig `mapstructure:"tls" json:"tls"` } // TLSConfig contains TLS/HTTPS configuration type TLSConfig struct { - Enabled bool `mapstructure:"enabled" json:"enabled"` CertFile string `mapstructure:"cert_file" json:"cert_file"` KeyFile string `mapstructure:"key_file" json:"key_file"` + Enabled bool `mapstructure:"enabled" json:"enabled"` } // StorageConfig contains storage backend configuration type StorageConfig struct { - Backend string `mapstructure:"backend" json:"backend"` // filesystem, s3, smb, nfs + Options map[string]interface{} `mapstructure:"options" json:"options"` + SMB SMBConfig `mapstructure:"smb" json:"smb"` + Backend string `mapstructure:"backend" json:"backend"` Path string `mapstructure:"path" json:"path"` Filesystem FilesystemConfig `mapstructure:"filesystem" json:"filesystem"` S3 S3Config `mapstructure:"s3" json:"s3"` - SMB SMBConfig `mapstructure:"smb" json:"smb"` - Options map[string]interface{} `mapstructure:"options" json:"options"` } // FilesystemConfig contains local filesystem storage configuration @@ -52,12 +52,14 @@ type FilesystemConfig struct { // S3Config contains S3-compatible storage configuration type S3Config struct { - Endpoint string `mapstructure:"endpoint" json:"endpoint"` - Region string `mapstructure:"region" json:"region"` - Bucket string `mapstructure:"bucket" json:"bucket"` - AccessKeyID string `mapstructure:"access_key_id" json:"access_key_id"` - SecretAccessKey string `mapstructure:"secret_access_key" json:"-"` // Don't serialize secrets - UseSSL bool `mapstructure:"use_ssl" json:"use_ssl"` + Endpoint string `mapstructure:"endpoint" json:"endpoint"` // Optional: for MinIO, etc. + Region string `mapstructure:"region" json:"region"` // AWS region (e.g., us-east-1) + Bucket string `mapstructure:"bucket" json:"bucket"` // S3 bucket name + Prefix string `mapstructure:"prefix" json:"prefix"` // Optional: key prefix + AccessKeyID string `mapstructure:"access_key_id" json:"access_key_id"` // AWS access key + SecretAccessKey string `mapstructure:"secret_access_key" json:"-"` // AWS secret key (not serialized) + ForcePathStyle bool `mapstructure:"force_path_style" json:"force_path_style"` // For MinIO compatibility + UseSSL bool `mapstructure:"use_ssl" json:"use_ssl"` // Deprecated: use endpoint with https:// } // SMBConfig contains SMB/CIFS storage configuration @@ -71,10 +73,10 @@ type SMBConfig struct { // MetadataConfig contains metadata store configuration type MetadataConfig struct { - Backend string `mapstructure:"backend" json:"backend"` // sqlite, postgresql, file + PostgreSQL PostgreSQLConfig `mapstructure:"postgresql" json:"postgresql"` + Backend string `mapstructure:"backend" json:"backend"` Connection string `mapstructure:"connection" json:"connection"` SQLite SQLiteConfig `mapstructure:"sqlite" json:"sqlite"` - PostgreSQL PostgreSQLConfig `mapstructure:"postgresql" json:"postgresql"` } // SQLiteConfig contains SQLite-specific configuration @@ -86,33 +88,33 @@ type SQLiteConfig struct { // PostgreSQLConfig contains PostgreSQL-specific configuration type PostgreSQLConfig struct { Host string `mapstructure:"host" json:"host"` - Port int `mapstructure:"port" json:"port"` Database string `mapstructure:"database" json:"database"` User string `mapstructure:"user" json:"user"` - Password string `mapstructure:"password" json:"-"` // Don't serialize secrets + Password string `mapstructure:"password" json:"-"` SSLMode string `mapstructure:"ssl_mode" json:"ssl_mode"` + Port int `mapstructure:"port" json:"port"` } // CacheConfig contains cache management configuration type CacheConfig struct { + TTLOverrides map[string]time.Duration `mapstructure:"ttl_overrides" json:"ttl_overrides"` DefaultTTL time.Duration `mapstructure:"default_ttl" json:"default_ttl"` CleanupInterval time.Duration `mapstructure:"cleanup_interval" json:"cleanup_interval"` MaxSizeBytes int64 `mapstructure:"max_size_bytes" json:"max_size_bytes"` PerProjectQuota int64 `mapstructure:"per_project_quota" json:"per_project_quota"` - TTLOverrides map[string]time.Duration `mapstructure:"ttl_overrides" json:"ttl_overrides"` // Per ecosystem } // SecurityConfig contains security scanning configuration type SecurityConfig struct { - Enabled bool `mapstructure:"enabled" json:"enabled"` - ScanOnDownload bool `mapstructure:"scan_on_download" json:"scan_on_download"` // Scan packages on first download - RescanInterval time.Duration `mapstructure:"rescan_interval" json:"rescan_interval"` // How often to re-scan (e.g., 24h, 168h for weekly) - BlockOnSeverity string `mapstructure:"block_on_severity" json:"block_on_severity"` // none, low, medium, high, critical - BlockThresholds VulnerabilityThresholds `mapstructure:"block_thresholds" json:"block_thresholds"` // Max vulns per severity before blocking - UpdateDBOnStartup bool `mapstructure:"update_db_on_startup" json:"update_db_on_startup"` // Update vulnerability databases on startup - AllowedPackages []string `mapstructure:"allowed_packages" json:"allowed_packages"` // Packages that bypass security checks (format: "registry/name@version" or "registry/name") - IgnoredCVEs []string `mapstructure:"ignored_cves" json:"ignored_cves"` // CVE IDs to ignore globally (e.g., "CVE-2021-23337") Scanners ScannersConfig `mapstructure:"scanners" json:"scanners"` + BlockOnSeverity string `mapstructure:"block_on_severity" json:"block_on_severity"` + AllowedPackages []string `mapstructure:"allowed_packages" json:"allowed_packages"` + IgnoredCVEs []string `mapstructure:"ignored_cves" json:"ignored_cves"` + BlockThresholds VulnerabilityThresholds `mapstructure:"block_thresholds" json:"block_thresholds"` + RescanInterval time.Duration `mapstructure:"rescan_interval" json:"rescan_interval"` + Enabled bool `mapstructure:"enabled" json:"enabled"` + ScanOnDownload bool `mapstructure:"scan_on_download" json:"scan_on_download"` + UpdateDBOnStartup bool `mapstructure:"update_db_on_startup" json:"update_db_on_startup"` } // VulnerabilityThresholds defines max allowed vulnerabilities per severity @@ -126,36 +128,36 @@ type VulnerabilityThresholds struct { // ScannersConfig contains individual scanner configurations type ScannersConfig struct { Trivy TrivyConfig `mapstructure:"trivy" json:"trivy"` - OSV OSVConfig `mapstructure:"osv" json:"osv"` + GHSA GHSAConfig `mapstructure:"ghsa" json:"ghsa"` Static StaticConfig `mapstructure:"static" json:"static"` + OSV OSVConfig `mapstructure:"osv" json:"osv"` Grype GrypeConfig `mapstructure:"grype" json:"grype"` Govulncheck GovulncheckConfig `mapstructure:"govulncheck" json:"govulncheck"` NpmAudit NpmAuditConfig `mapstructure:"npm_audit" json:"npm_audit"` PipAudit PipAuditConfig `mapstructure:"pip_audit" json:"pip_audit"` - GHSA GHSAConfig `mapstructure:"ghsa" json:"ghsa"` } // TrivyConfig contains Trivy scanner configuration type TrivyConfig struct { - Enabled bool `mapstructure:"enabled" json:"enabled"` - Timeout time.Duration `mapstructure:"timeout" json:"timeout"` CacheDB string `mapstructure:"cache_db" json:"cache_db"` + Timeout time.Duration `mapstructure:"timeout" json:"timeout"` + Enabled bool `mapstructure:"enabled" json:"enabled"` } // OSVConfig contains OSV scanner configuration type OSVConfig struct { - Enabled bool `mapstructure:"enabled" json:"enabled"` APIURL string `mapstructure:"api_url" json:"api_url"` Timeout time.Duration `mapstructure:"timeout" json:"timeout"` + Enabled bool `mapstructure:"enabled" json:"enabled"` } // StaticConfig contains static analysis configuration type StaticConfig struct { - Enabled bool `mapstructure:"enabled" json:"enabled"` + AllowedLicenses []string `mapstructure:"allowed_licenses" json:"allowed_licenses"` MaxPackageSize int64 `mapstructure:"max_package_size" json:"max_package_size"` + Enabled bool `mapstructure:"enabled" json:"enabled"` CheckChecksums bool `mapstructure:"check_checksums" json:"check_checksums"` BlockSuspicious bool `mapstructure:"block_suspicious" json:"block_suspicious"` - AllowedLicenses []string `mapstructure:"allowed_licenses" json:"allowed_licenses"` } // GrypeConfig contains Grype scanner configuration @@ -184,16 +186,16 @@ type PipAuditConfig struct { // GHSAConfig contains GitHub Advisory Database scanner configuration type GHSAConfig struct { - Enabled bool `mapstructure:"enabled" json:"enabled"` + Token string `mapstructure:"token" json:"-"` Timeout time.Duration `mapstructure:"timeout" json:"timeout"` - Token string `mapstructure:"token" json:"-"` // GitHub token for higher rate limits (don't serialize) + Enabled bool `mapstructure:"enabled" json:"enabled"` } // AuthConfig contains authentication configuration type AuthConfig struct { - Enabled bool `mapstructure:"enabled" json:"enabled"` KeyExpiration time.Duration `mapstructure:"key_expiration" json:"key_expiration"` BcryptCost int `mapstructure:"bcrypt_cost" json:"bcrypt_cost"` + Enabled bool `mapstructure:"enabled" json:"enabled"` AuditLog bool `mapstructure:"audit_log" json:"audit_log"` } @@ -245,24 +247,24 @@ type HandlersConfig struct { // GoHandlerConfig contains Go proxy configuration type GoHandlerConfig struct { - Enabled bool `mapstructure:"enabled" json:"enabled"` UpstreamProxy string `mapstructure:"upstream_proxy" json:"upstream_proxy"` ChecksumDB string `mapstructure:"checksum_db" json:"checksum_db"` + GitCredentialsFile string `mapstructure:"git_credentials_file" json:"git_credentials_file"` + Enabled bool `mapstructure:"enabled" json:"enabled"` VerifyChecksums bool `mapstructure:"verify_checksums" json:"verify_checksums"` - GitCredentialsFile string `mapstructure:"git_credentials_file" json:"git_credentials_file"` // Path to git credentials JSON file } // NPMHandlerConfig contains NPM registry configuration type NPMHandlerConfig struct { - Enabled bool `mapstructure:"enabled" json:"enabled"` UpstreamRegistry string `mapstructure:"upstream_registry" json:"upstream_registry"` + Enabled bool `mapstructure:"enabled" json:"enabled"` } // PyPIHandlerConfig contains PyPI configuration type PyPIHandlerConfig struct { - Enabled bool `mapstructure:"enabled" json:"enabled"` UpstreamURL string `mapstructure:"upstream_url" json:"upstream_url"` SimpleAPIURL string `mapstructure:"simple_api_url" json:"simple_api_url"` + Enabled bool `mapstructure:"enabled" json:"enabled"` } // Default returns a configuration with sensible defaults diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 211b585..b90f434 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -41,10 +41,10 @@ func (s *ConfigTestSuite) TestDefault() { func (s *ConfigTestSuite) TestValidate() { tests := []struct { - name string modify func(*Config) - expectError bool + name string errorSubstr string + expectError bool }{ { name: "valid_config", @@ -175,11 +175,11 @@ func (s *ConfigTestSuite) TestValidate() { func (s *ConfigTestSuite) TestLoad() { tests := []struct { + envVars map[string]string + validate func(*Config) name string configYAML string - envVars map[string]string expectError bool - validate func(*Config) }{ { name: "valid_yaml_config", @@ -319,13 +319,6 @@ func (s *ConfigTestSuite) TestLoadMissingFile() { s.Nil(cfg) } -func (s *ConfigTestSuite) TestLoadWithDefaults() { - // Invalid config path should return defaults - cfg := LoadWithDefaults("/invalid/path/config.yaml") - s.NotNil(cfg) - s.Equal(8080, cfg.Server.Port) -} - // Benchmark tests func BenchmarkDefault(b *testing.B) { for i := 0; i < b.N; i++ { @@ -344,8 +337,8 @@ func BenchmarkValidate(b *testing.B) { // Table-driven edge cases func TestConfigEdgeCases(t *testing.T) { tests := []struct { - name string config *Config + name string valid bool }{ { diff --git a/pkg/config/loader.go b/pkg/config/loader.go index 5800df8..067884c 100644 --- a/pkg/config/loader.go +++ b/pkg/config/loader.go @@ -51,12 +51,3 @@ func Load(configPath string) (*Config, error) { return cfg, nil } - -// LoadWithDefaults loads configuration or returns defaults on error -func LoadWithDefaults(configPath string) *Config { - cfg, err := Load(configPath) - if err != nil { - return Default() - } - return cfg -} diff --git a/pkg/errors/codes.go b/pkg/errors/codes.go index 443dfa0..c87aa77 100644 --- a/pkg/errors/codes.go +++ b/pkg/errors/codes.go @@ -58,11 +58,3 @@ var HTTPStatusCode = map[string]int{ ErrCodeServiceUnavailable: 503, ErrCodeCircuitOpen: 503, } - -// GetHTTPStatus returns the HTTP status code for an error code -func GetHTTPStatus(code string) int { - if status, ok := HTTPStatusCode[code]; ok { - return status - } - return 500 // Default to internal server error -} diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go index cebb637..803e802 100644 --- a/pkg/errors/errors.go +++ b/pkg/errors/errors.go @@ -6,11 +6,11 @@ import ( // Error represents a structured error with code and details type Error struct { + Details interface{} `json:"details,omitempty"` + Cause error `json:"-"` Code string `json:"code"` Message string `json:"message"` - Details interface{} `json:"details,omitempty"` Trace []string `json:"trace,omitempty"` - Cause error `json:"-"` // Internal cause, not serialized } // Error implements the error interface @@ -34,26 +34,12 @@ func New(code, message string) *Error { } } -// Newf creates a new error with formatted message -func Newf(code, format string, args ...interface{}) *Error { - return &Error{ - Code: code, - Message: fmt.Sprintf(format, args...), - } -} - // WithDetails adds details to the error func (e *Error) WithDetails(details interface{}) *Error { e.Details = details return e } -// WithTrace adds stack trace to the error -func (e *Error) WithTrace(trace []string) *Error { - e.Trace = trace - return e -} - // WithCause adds an underlying cause to the error func (e *Error) WithCause(cause error) *Error { e.Cause = cause @@ -69,44 +55,12 @@ func Wrap(err error, code, message string) *Error { } } -// Wrapf wraps an existing error with formatted message -func Wrapf(err error, code, format string, args ...interface{}) *Error { - return &Error{ - Code: code, - Message: fmt.Sprintf(format, args...), - Cause: err, - } -} - -// Common error constructors -func BadRequest(message string) *Error { - return New(ErrCodeBadRequest, message) -} - -func Unauthorized(message string) *Error { - return New(ErrCodeUnauthorized, message) -} - -func Forbidden(message string) *Error { - return New(ErrCodeForbidden, message) -} - +// NotFound creates a not found error func NotFound(message string) *Error { return New(ErrCodeNotFound, message) } -func InternalServer(message string) *Error { - return New(ErrCodeInternalServer, message) -} - -func PackageNotFound(name, version string) *Error { - return New(ErrCodePackageNotFound, fmt.Sprintf("Package %s@%s not found", name, version)). - WithDetails(map[string]string{ - "package": name, - "version": version, - }) -} - +// QuotaExceeded creates a quota exceeded error func QuotaExceeded(limit int64) *Error { return New(ErrCodeQuotaExceeded, "Storage quota exceeded"). WithDetails(map[string]interface{}{ diff --git a/pkg/errors/errors_test.go b/pkg/errors/errors_test.go index 77885f0..b779f07 100644 --- a/pkg/errors/errors_test.go +++ b/pkg/errors/errors_test.go @@ -4,7 +4,6 @@ import ( "errors" "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" ) @@ -46,43 +45,10 @@ func (s *ErrorsTestSuite) TestNew() { } } -func (s *ErrorsTestSuite) TestNewf() { - tests := []struct { - name string - code string - format string - args []interface{} - expected string - }{ - { - name: "formatted_message", - code: ErrCodePackageNotFound, - format: "Package %s@%s not found", - args: []interface{}{"react", "18.2.0"}, - expected: "Package react@18.2.0 not found", - }, - { - name: "no_args", - code: ErrCodeInternalServer, - format: "Internal error", - args: []interface{}{}, - expected: "Internal error", - }, - } - - for _, tt := range tests { - s.Run(tt.name, func() { - err := Newf(tt.code, tt.format, tt.args...) - s.Equal(tt.code, err.Code) - s.Equal(tt.expected, err.Message) - }) - } -} - func (s *ErrorsTestSuite) TestWithDetails() { tests := []struct { - name string details interface{} + name string }{ { name: "map_details", @@ -106,12 +72,6 @@ func (s *ErrorsTestSuite) TestWithDetails() { } } -func (s *ErrorsTestSuite) TestWithTrace() { - trace := []string{"file1.go:10", "file2.go:20"} - err := New(ErrCodeInternalServer, "test").WithTrace(trace) - s.Equal(trace, err.Trace) -} - func (s *ErrorsTestSuite) TestWithCause() { cause := errors.New("underlying error") err := New(ErrCodeStorageFailure, "test").WithCause(cause) @@ -129,15 +89,6 @@ func (s *ErrorsTestSuite) TestWrap() { s.True(errors.Is(wrapped, cause)) } -func (s *ErrorsTestSuite) TestWrapf() { - cause := errors.New("connection refused") - wrapped := Wrapf(cause, ErrCodeUpstreamFailure, "failed to connect to %s", "registry.npmjs.org") - - s.Equal(ErrCodeUpstreamFailure, wrapped.Code) - s.Equal("failed to connect to registry.npmjs.org", wrapped.Message) - s.Equal(cause, wrapped.Cause) -} - func (s *ErrorsTestSuite) TestErrorString() { tests := []struct { name string @@ -163,59 +114,6 @@ func (s *ErrorsTestSuite) TestErrorString() { } } -func (s *ErrorsTestSuite) TestCommonConstructors() { - tests := []struct { - name string - fn func() *Error - wantCode string - }{ - { - name: "bad_request", - fn: func() *Error { return BadRequest("invalid input") }, - wantCode: ErrCodeBadRequest, - }, - { - name: "unauthorized", - fn: func() *Error { return Unauthorized("invalid token") }, - wantCode: ErrCodeUnauthorized, - }, - { - name: "forbidden", - fn: func() *Error { return Forbidden("access denied") }, - wantCode: ErrCodeForbidden, - }, - { - name: "not_found", - fn: func() *Error { return NotFound("resource missing") }, - wantCode: ErrCodeNotFound, - }, - { - name: "internal_server", - fn: func() *Error { return InternalServer("server error") }, - wantCode: ErrCodeInternalServer, - }, - } - - for _, tt := range tests { - s.Run(tt.name, func() { - err := tt.fn() - s.Equal(tt.wantCode, err.Code) - }) - } -} - -func (s *ErrorsTestSuite) TestPackageNotFound() { - err := PackageNotFound("lodash", "4.17.21") - s.Equal(ErrCodePackageNotFound, err.Code) - s.Equal("Package lodash@4.17.21 not found", err.Message) - s.NotNil(err.Details) - - details, ok := err.Details.(map[string]string) - s.True(ok) - s.Equal("lodash", details["package"]) - s.Equal("4.17.21", details["version"]) -} - func (s *ErrorsTestSuite) TestQuotaExceeded() { limit := int64(1000000) err := QuotaExceeded(limit) @@ -275,31 +173,3 @@ func (s *ErrorsTestSuite) TestEdgeCases() { s.Equal(largeDetails, err.Details) }) } - -// Table-driven test for error codes -func TestGetHTTPStatus(t *testing.T) { - tests := []struct { - code string - expectedStatus int - }{ - {ErrCodeBadRequest, 400}, - {ErrCodeUnauthorized, 401}, - {ErrCodeForbidden, 403}, - {ErrCodeNotFound, 404}, - {ErrCodeConflict, 409}, - {ErrCodePayloadTooLarge, 413}, - {ErrCodeChecksumMismatch, 422}, - {ErrCodeRateLimited, 429}, - {ErrCodeInternalServer, 500}, - {ErrCodeDatabaseFailure, 500}, - {ErrCodeUpstreamFailure, 502}, - {ErrCodeServiceUnavailable, 503}, - {"UNKNOWN_CODE", 500}, // Default - } - - for _, tt := range tests { - t.Run(tt.code, func(t *testing.T) { - assert.Equal(t, tt.expectedStatus, GetHTTPStatus(tt.code)) - }) - } -} diff --git a/pkg/errors/response.go b/pkg/errors/response.go deleted file mode 100644 index 362c85d..0000000 --- a/pkg/errors/response.go +++ /dev/null @@ -1,90 +0,0 @@ -package errors - -import ( - "net/http" - "time" - - json "github.com/goccy/go-json" -) - -// Response is the standard API response envelope -type Response struct { - Success bool `json:"success"` - Data interface{} `json:"data,omitempty"` - Error *ErrorResponse `json:"error,omitempty"` - Metadata *ResponseMeta `json:"metadata,omitempty"` -} - -// ErrorResponse contains error details -type ErrorResponse struct { - Code string `json:"code"` - Message string `json:"message"` - Details interface{} `json:"details,omitempty"` - Trace []string `json:"trace,omitempty"` -} - -// ResponseMeta contains request metadata -type ResponseMeta struct { - RequestID string `json:"request_id"` - Timestamp string `json:"timestamp"` - Duration string `json:"duration,omitempty"` - Version string `json:"version"` -} - -// WriteJSON writes a success response as JSON -func WriteJSON(w http.ResponseWriter, statusCode int, data interface{}, meta *ResponseMeta) { - response := Response{ - Success: statusCode < 400, - Data: data, - Metadata: meta, - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(statusCode) - - if err := json.NewEncoder(w).Encode(response); err != nil { - // Fallback to simple error response - http.Error(w, `{"success":false,"error":{"code":"ENCODING_ERROR","message":"Failed to encode response"}}`, http.StatusInternalServerError) - } -} - -// WriteError writes an error response as JSON -func WriteError(w http.ResponseWriter, statusCode int, err *Error, meta *ResponseMeta) { - errResp := &ErrorResponse{ - Code: err.Code, - Message: err.Message, - Details: err.Details, - Trace: err.Trace, - } - - response := Response{ - Success: false, - Error: errResp, - Metadata: meta, - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(statusCode) - - if encErr := json.NewEncoder(w).Encode(response); encErr != nil { - // Fallback to simple error response - http.Error(w, `{"success":false,"error":{"code":"ENCODING_ERROR","message":"Failed to encode error response"}}`, http.StatusInternalServerError) - } -} - -// WriteErrorSimple writes an error without metadata -func WriteErrorSimple(w http.ResponseWriter, err *Error) { - statusCode := GetHTTPStatus(err.Code) - meta := &ResponseMeta{ - Timestamp: time.Now().UTC().Format(time.RFC3339), - } - WriteError(w, statusCode, err, meta) -} - -// WriteJSONSimple writes a success response without metadata -func WriteJSONSimple(w http.ResponseWriter, statusCode int, data interface{}) { - meta := &ResponseMeta{ - Timestamp: time.Now().UTC().Format(time.RFC3339), - } - WriteJSON(w, statusCode, data, meta) -} diff --git a/pkg/health/health.go b/pkg/health/health.go index 1504427..817b0a5 100644 --- a/pkg/health/health.go +++ b/pkg/health/health.go @@ -21,25 +21,25 @@ const ( // Check represents a single health check type Check struct { + Fn func(context.Context) (Status, string) `json:"-"` Name string `json:"name"` Status Status `json:"status"` Error string `json:"error,omitempty"` - Fn func(context.Context) (Status, string) `json:"-"` } // Response is the health check response type Response struct { - Success bool `json:"success"` Data *HealthData `json:"data,omitempty"` Metadata *Metadata `json:"metadata,omitempty"` + Success bool `json:"success"` } // HealthData contains health check data type HealthData struct { + Components map[string]*Component `json:"components"` Status Status `json:"status"` Version string `json:"version"` Uptime string `json:"uptime"` - Components map[string]*Component `json:"components"` } // Component represents a system component @@ -57,8 +57,8 @@ type Metadata struct { // Checker manages health checks type Checker struct { - checks []*Check startTime time.Time + checks []*Check mu sync.RWMutex } diff --git a/pkg/lock/redis.go b/pkg/lock/redis.go deleted file mode 100644 index ee13873..0000000 --- a/pkg/lock/redis.go +++ /dev/null @@ -1,275 +0,0 @@ -package lock - -import ( - "context" - "crypto/rand" - "encoding/hex" - "errors" - "time" - - "github.com/redis/go-redis/v9" - "github.com/rs/zerolog/log" -) - -var ( - ErrLockNotAcquired = errors.New("lock not acquired") - ErrLockNotHeld = errors.New("lock not held by this instance") - ErrInvalidTTL = errors.New("invalid TTL: must be positive") -) - -// Lock represents a distributed lock -type Lock struct { - client *redis.Client - key string - value string - ttl time.Duration -} - -// Manager manages distributed locks using Redis -type Manager struct { - client *redis.Client -} - -// Config holds Redis connection configuration -type Config struct { - Addr string - Password string - DB int -} - -// NewManager creates a new lock manager -func NewManager(cfg Config) (*Manager, error) { - client := redis.NewClient(&redis.Options{ - Addr: cfg.Addr, - Password: cfg.Password, - DB: cfg.DB, - }) - - // Test connection - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - if err := client.Ping(ctx).Err(); err != nil { - return nil, err - } - - log.Info(). - Str("addr", cfg.Addr). - Int("db", cfg.DB). - Msg("Connected to Redis for distributed locking") - - return &Manager{ - client: client, - }, nil -} - -// Acquire attempts to acquire a lock with the given key and TTL -// Returns a Lock instance if successful, or an error if the lock is already held -func (m *Manager) Acquire(ctx context.Context, key string, ttl time.Duration) (*Lock, error) { - if ttl <= 0 { - return nil, ErrInvalidTTL - } - - // Generate unique value for this lock instance - value, err := generateLockValue() - if err != nil { - return nil, err - } - - // Try to acquire lock using SET NX (set if not exists) - success, err := m.client.SetNX(ctx, key, value, ttl).Result() - if err != nil { - log.Error(). - Err(err). - Str("key", key). - Msg("Failed to acquire lock") - return nil, err - } - - if !success { - log.Debug(). - Str("key", key). - Msg("Lock already held by another instance") - return nil, ErrLockNotAcquired - } - - log.Debug(). - Str("key", key). - Dur("ttl", ttl). - Msg("Lock acquired successfully") - - return &Lock{ - client: m.client, - key: key, - value: value, - ttl: ttl, - }, nil -} - -// TryAcquire attempts to acquire a lock, retrying for the specified duration -// Returns a Lock instance if successful within the timeout, or an error -func (m *Manager) TryAcquire(ctx context.Context, key string, ttl, timeout time.Duration) (*Lock, error) { - if ttl <= 0 { - return nil, ErrInvalidTTL - } - - deadline := time.Now().Add(timeout) - ticker := time.NewTicker(50 * time.Millisecond) - defer ticker.Stop() - - for { - lock, err := m.Acquire(ctx, key, ttl) - if err == nil { - return lock, nil - } - - if err != ErrLockNotAcquired { - return nil, err - } - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-ticker.C: - if time.Now().After(deadline) { - return nil, ErrLockNotAcquired - } - } - } -} - -// Release releases the lock -// Returns an error if the lock is not held by this instance -func (l *Lock) Release(ctx context.Context) error { - // Use Lua script to ensure atomic check-and-delete - // Only delete if the value matches (ensures we own the lock) - script := redis.NewScript(` - if redis.call("get", KEYS[1]) == ARGV[1] then - return redis.call("del", KEYS[1]) - else - return 0 - end - `) - - result, err := script.Run(ctx, l.client, []string{l.key}, l.value).Result() - if err != nil { - log.Error(). - Err(err). - Str("key", l.key). - Msg("Failed to release lock") - return err - } - - // Result of 0 means the lock was not deleted (not owned by us) - if result.(int64) == 0 { - log.Warn(). - Str("key", l.key). - Msg("Attempted to release lock not held by this instance") - return ErrLockNotHeld - } - - log.Debug(). - Str("key", l.key). - Msg("Lock released successfully") - - return nil -} - -// Extend extends the lock TTL -// Returns an error if the lock is not held by this instance -func (l *Lock) Extend(ctx context.Context, additionalTTL time.Duration) error { - // Use Lua script to ensure atomic check-and-extend - script := redis.NewScript(` - if redis.call("get", KEYS[1]) == ARGV[1] then - return redis.call("expire", KEYS[1], ARGV[2]) - else - return 0 - end - `) - - newTTL := l.ttl + additionalTTL - result, err := script.Run(ctx, l.client, []string{l.key}, l.value, int(newTTL.Seconds())).Result() - if err != nil { - log.Error(). - Err(err). - Str("key", l.key). - Msg("Failed to extend lock") - return err - } - - if result.(int64) == 0 { - log.Warn(). - Str("key", l.key). - Msg("Attempted to extend lock not held by this instance") - return ErrLockNotHeld - } - - l.ttl = newTTL - log.Debug(). - Str("key", l.key). - Dur("new_ttl", newTTL). - Msg("Lock TTL extended") - - return nil -} - -// IsHeld checks if the lock is still held by this instance -func (l *Lock) IsHeld(ctx context.Context) bool { - value, err := l.client.Get(ctx, l.key).Result() - if err != nil { - return false - } - return value == l.value -} - -// Close closes the lock manager and its Redis connection -func (m *Manager) Close() error { - return m.client.Close() // #nosec G104 -- Cleanup, error not critical -} - -// generateLockValue generates a cryptographically random lock value -func generateLockValue() (string, error) { - bytes := make([]byte, 16) - if _, err := rand.Read(bytes); err != nil { - return "", err - } - return hex.EncodeToString(bytes), nil -} - -// WithLock executes a function while holding a distributed lock -// The lock is automatically released when the function returns -func (m *Manager) WithLock(ctx context.Context, key string, ttl time.Duration, fn func(context.Context) error) error { - lock, err := m.Acquire(ctx, key, ttl) - if err != nil { - return err - } - defer func() { - if err := lock.Release(context.Background()); err != nil { - log.Error(). - Err(err). - Str("key", key). - Msg("Failed to release lock in defer") - } - }() - - return fn(ctx) -} - -// WithRetryLock executes a function while holding a distributed lock -// It retries acquisition for the specified timeout duration -func (m *Manager) WithRetryLock(ctx context.Context, key string, ttl, timeout time.Duration, fn func(context.Context) error) error { - lock, err := m.TryAcquire(ctx, key, ttl, timeout) - if err != nil { - return err - } - defer func() { - if err := lock.Release(context.Background()); err != nil { - log.Error(). - Err(err). - Str("key", key). - Msg("Failed to release lock in defer") - } - }() - - return fn(ctx) -} diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index fb1a73e..9e167ba 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -35,23 +35,3 @@ func Init(cfg Config) error { return nil } - -// Get returns the global logger -func Get() *zerolog.Logger { - return &log.Logger -} - -// WithFields returns a logger with additional fields -func WithFields(fields map[string]interface{}) *zerolog.Logger { - logger := log.Logger - for k, v := range fields { - logger = logger.With().Interface(k, v).Logger() - } - return &logger -} - -// WithRequestID returns a logger with request ID -func WithRequestID(requestID string) *zerolog.Logger { - logger := log.With().Str("request_id", requestID).Logger() - return &logger -} diff --git a/pkg/logger/middleware.go b/pkg/logger/middleware.go deleted file mode 100644 index f7a5e1d..0000000 --- a/pkg/logger/middleware.go +++ /dev/null @@ -1,65 +0,0 @@ -package logger - -import ( - "net/http" - "time" - - "github.com/lukaszraczylo/gohoarder/pkg/uuid" - "github.com/rs/zerolog/log" -) - -// responseWriter wraps http.ResponseWriter to capture status code -type responseWriter struct { - http.ResponseWriter - statusCode int - written int64 -} - -func (rw *responseWriter) WriteHeader(code int) { - rw.statusCode = code - rw.ResponseWriter.WriteHeader(code) -} - -func (rw *responseWriter) Write(b []byte) (int, error) { - n, err := rw.ResponseWriter.Write(b) - rw.written += int64(n) - return n, err -} - -// Middleware is HTTP middleware for request logging -func Middleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - start := time.Now() - - // Generate request ID - requestID := r.Header.Get("X-Request-ID") - if requestID == "" { - requestID = "req_" + uuid.New().String()[:8] - } - - // Wrap response writer - rw := &responseWriter{ - ResponseWriter: w, - statusCode: http.StatusOK, - } - - // Set request ID in response header - rw.Header().Set("X-Request-ID", requestID) - - // Call next handler - next.ServeHTTP(rw, r) - - // Log request - duration := time.Since(start) - log.Info(). - Str("request_id", requestID). - Str("method", r.Method). - Str("path", r.URL.Path). - Str("remote_addr", r.RemoteAddr). - Str("user_agent", r.UserAgent()). - Int("status", rw.statusCode). - Int64("bytes", rw.written). - Dur("duration_ms", duration). - Msg("HTTP request") - }) -} diff --git a/pkg/metadata/interface.go b/pkg/metadata/interface.go index 95aa6da..6aa4175 100644 --- a/pkg/metadata/interface.go +++ b/pkg/metadata/interface.go @@ -68,37 +68,37 @@ type MetadataStore interface { // Package represents package metadata type Package struct { + CachedAt time.Time `json:"cached_at"` + LastAccessed time.Time `json:"last_accessed"` + Metadata map[string]string `json:"metadata"` + ExpiresAt *time.Time `json:"expires_at"` + UpstreamURL string `json:"upstream_url"` + ChecksumMD5 string `json:"checksum_md5"` + ChecksumSHA256 string `json:"checksum_sha256"` ID string `json:"id"` - Registry string `json:"registry"` // npm, pypi, go - Name string `json:"name"` // Package name - Version string `json:"version"` // Package version - StorageKey string `json:"storage_key"` // Key in storage backend - Size int64 `json:"size"` // Package size in bytes - ChecksumMD5 string `json:"checksum_md5"` // MD5 checksum - ChecksumSHA256 string `json:"checksum_sha256"` // SHA256 checksum - UpstreamURL string `json:"upstream_url"` // Original upstream URL - CachedAt time.Time `json:"cached_at"` // When cached - LastAccessed time.Time `json:"last_accessed"` // Last access time - ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = never) - DownloadCount int64 `json:"download_count"` // Download counter - Metadata map[string]string `json:"metadata"` // Additional metadata - SecurityScanned bool `json:"security_scanned"` // Has been scanned - RequiresAuth bool `json:"requires_auth"` // Package requires authentication - AuthProvider string `json:"auth_provider"` // Auth provider (github.com, npm.pkg.github.com, etc.) + StorageKey string `json:"storage_key"` + Version string `json:"version"` + Name string `json:"name"` + Registry string `json:"registry"` + AuthProvider string `json:"auth_provider"` + Size int64 `json:"size"` + DownloadCount int64 `json:"download_count"` + SecurityScanned bool `json:"security_scanned"` + RequiresAuth bool `json:"requires_auth"` } // ScanResult represents a security scan result type ScanResult struct { + ScannedAt time.Time `json:"scanned_at"` + Details map[string]interface{} `json:"details"` ID string `json:"id"` Registry string `json:"registry"` PackageName string `json:"package_name"` PackageVersion string `json:"package_version"` - Scanner string `json:"scanner"` // trivy, osv, etc. - ScannedAt time.Time `json:"scanned_at"` - Status ScanStatus `json:"status"` // clean, vulnerable, error - VulnerabilityCount int `json:"vulnerability_count"` + Scanner string `json:"scanner"` + Status ScanStatus `json:"status"` Vulnerabilities []Vulnerability `json:"vulnerabilities"` - Details map[string]interface{} `json:"details"` // Scanner-specific details + VulnerabilityCount int `json:"vulnerability_count"` } // Vulnerability represents a security vulnerability @@ -143,13 +143,13 @@ const ( // Stats represents metadata statistics type Stats struct { + LastUpdated time.Time `json:"last_updated"` Registry string `json:"registry"` TotalPackages int64 `json:"total_packages"` TotalSize int64 `json:"total_size"` TotalDownloads int64 `json:"total_downloads"` ScannedPackages int64 `json:"scanned_packages"` VulnerablePackages int64 `json:"vulnerable_packages"` - LastUpdated time.Time `json:"last_updated"` } // TimeSeriesDataPoint represents a single data point in time-series @@ -198,14 +198,14 @@ type BypassListOptions struct { // ListOptions contains options for listing packages type ListOptions struct { - Registry string // Filter by registry - NamePrefix string // Filter by name prefix - MinSize int64 // Minimum package size - MaxSize int64 // Maximum package size - ScannedOnly bool // Only scanned packages - SinceDate time.Time // Packages cached since date - Limit int // Max results - Offset int // Pagination offset - SortBy string // Sort field (name, size, cached_at, download_count) - SortDesc bool // Sort descending + SinceDate time.Time + Registry string + NamePrefix string + SortBy string + MinSize int64 + MaxSize int64 + Limit int + Offset int + ScannedOnly bool + SortDesc bool } diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index f4ed868..d84533e 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -137,21 +137,11 @@ func RecordCacheMiss(handler string) { CacheRequests.WithLabelValues("miss", handler).Inc() } -// RecordCacheError records a cache error -func RecordCacheError(handler string) { - CacheRequests.WithLabelValues("error", handler).Inc() -} - // UpdateCacheSize updates the cache size metric func UpdateCacheSize(backend string, bytes int64) { CacheSizeBytes.WithLabelValues(backend).Set(float64(bytes)) } -// UpdateCacheItems updates the cache items metric -func UpdateCacheItems(handler string, count int64) { - CacheItemsTotal.WithLabelValues(handler).Set(float64(count)) -} - // RecordCacheEviction records a cache eviction func RecordCacheEviction(reason string) { CacheEvictions.WithLabelValues(reason).Inc() @@ -162,26 +152,11 @@ func RecordStorageOperation(backend, operation, status string) { StorageOperations.WithLabelValues(backend, operation, status).Inc() } -// UpdateStorageQuota updates the storage quota metric -func UpdateStorageQuota(project string, bytes int64) { - StorageQuotaBytes.WithLabelValues(project).Set(float64(bytes)) -} - // RecordUpstreamRequest records an upstream request func RecordUpstreamRequest(registry, status string) { UpstreamRequests.WithLabelValues(registry, status).Inc() } -// RecordSecurityScan records a security scan -func RecordSecurityScan(scanner, result string) { - SecurityScans.WithLabelValues(scanner, result).Inc() -} - -// RecordVulnerability records a vulnerability finding -func RecordVulnerability(severity string) { - VulnerabilitiesFound.WithLabelValues(severity).Inc() -} - // UpdateCircuitBreakerState updates the circuit breaker state func UpdateCircuitBreakerState(name string, state int) { CircuitBreakerState.WithLabelValues(name).Set(float64(state)) diff --git a/pkg/network/client.go b/pkg/network/client.go index 2f68b45..929e3ca 100644 --- a/pkg/network/client.go +++ b/pkg/network/client.go @@ -24,23 +24,23 @@ type Client struct { // Config holds client configuration type Config struct { - Timeout time.Duration // Request timeout - MaxRetries int // Max retry attempts - RetryDelay time.Duration // Initial retry delay - RateLimit float64 // Requests per second (0 = unlimited) - RateBurst int // Rate limiter burst - CircuitBreaker CircuitBreakerConfig UserAgent string + CircuitBreaker CircuitBreakerConfig + Timeout time.Duration + MaxRetries int + RetryDelay time.Duration + RateLimit float64 + RateBurst int MaxConnsPerHost int } // RetryConfig holds retry configuration type RetryConfig struct { + FixedDelays []time.Duration MaxAttempts int InitialDelay time.Duration MaxDelay time.Duration Multiplier float64 - FixedDelays []time.Duration // If set, use these delays instead of exponential backoff } // CircuitBreakerConfig holds circuit breaker configuration @@ -63,11 +63,11 @@ const ( // CircuitBreaker implements the circuit breaker pattern type CircuitBreaker struct { + lastFailureTime time.Time config CircuitBreakerConfig state CircuitBreakerState failures int successes int - lastFailureTime time.Time halfOpenCalls int mu sync.RWMutex } diff --git a/pkg/network/client_test.go b/pkg/network/client_test.go index 2860e95..5e0baa8 100644 --- a/pkg/network/client_test.go +++ b/pkg/network/client_test.go @@ -19,14 +19,14 @@ import ( // TestClientGet tests the HTTP client Get method with various scenarios func TestClientGet(t *testing.T) { tests := []struct { - name string serverBehavior func(*testing.T) *httptest.Server - config network.Config headers map[string]string - wantErr bool - errContains string validateBody func(*testing.T, io.ReadCloser) validateStatus func(*testing.T, int) + name string + errContains string + config network.Config + wantErr bool }{ // GOOD: Successful GET request { diff --git a/pkg/prewarming/worker.go b/pkg/prewarming/worker.go index 5f969e9..ca1b284 100644 --- a/pkg/prewarming/worker.go +++ b/pkg/prewarming/worker.go @@ -24,22 +24,22 @@ type Worker struct { cache *cache.Manager analytics *analytics.Engine client *network.Client + stopChan chan struct{} + wg sync.WaitGroup interval time.Duration maxConcurrent int enabled bool - stopChan chan struct{} - wg sync.WaitGroup } // Config holds pre-warming worker configuration type Config struct { - Enabled bool - Interval time.Duration - MaxConcurrent int - TopPackages int CacheManager *cache.Manager Analytics *analytics.Engine NetworkClient *network.Client + Interval time.Duration + MaxConcurrent int + TopPackages int + Enabled bool } // NewWorker creates a new pre-warming worker diff --git a/pkg/proxy/common/base.go b/pkg/proxy/common/base.go deleted file mode 100644 index 28ac205..0000000 --- a/pkg/proxy/common/base.go +++ /dev/null @@ -1,34 +0,0 @@ -package common - -import ( - "github.com/lukaszraczylo/gohoarder/pkg/cache" - "github.com/lukaszraczylo/gohoarder/pkg/network" -) - -// BaseHandler provides common functionality for all proxy handlers -type BaseHandler struct { - Cache *cache.Manager - Client *network.Client - Upstream string - Registry string -} - -// Config holds common proxy configuration -type Config struct { - Upstream string // Upstream registry URL (e.g., registry.npmjs.org) -} - -// GetRegistry returns the registry type -func (h *BaseHandler) GetRegistry() string { - return h.Registry -} - -// NewBaseHandler creates a new base handler with common fields -func NewBaseHandler(cache *cache.Manager, client *network.Client, registry, upstream string) *BaseHandler { - return &BaseHandler{ - Cache: cache, - Client: client, - Upstream: upstream, - Registry: registry, - } -} diff --git a/pkg/proxy/common/common_test.go b/pkg/proxy/common/common_test.go deleted file mode 100644 index 1dcd70c..0000000 --- a/pkg/proxy/common/common_test.go +++ /dev/null @@ -1,385 +0,0 @@ -package common - -import ( - "bytes" - "context" - "errors" - "io" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/lukaszraczylo/gohoarder/pkg/cache" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// TestNewBaseHandler tests base handler creation -func TestNewBaseHandler(t *testing.T) { - // Use nil for cache and client since we're only testing structure - handler := NewBaseHandler(nil, nil, "npm", "https://registry.npmjs.org") - - require.NotNil(t, handler) - assert.Equal(t, "npm", handler.Registry) - assert.Equal(t, "https://registry.npmjs.org", handler.Upstream) - assert.Nil(t, handler.Cache) - assert.Nil(t, handler.Client) -} - -// TestGetRegistry tests registry type retrieval -func TestGetRegistry(t *testing.T) { - tests := []struct { - name string - registry string - }{ - {"npm registry", "npm"}, - {"pypi registry", "pypi"}, - {"go registry", "go"}, - {"custom registry", "custom"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - handler := &BaseHandler{Registry: tt.registry} - assert.Equal(t, tt.registry, handler.GetRegistry()) - }) - } -} - -// TestHandleUpstreamError tests upstream error handling -func TestHandleUpstreamError(t *testing.T) { - tests := []struct { - name string - err error - url string - context string - wantStatus int - wantContain string - }{ - // GOOD: Standard error - { - name: "connection error", - err: errors.New("connection refused"), - url: "https://registry.npmjs.org/react", - context: "package", - wantStatus: http.StatusBadGateway, - wantContain: "Failed to fetch package", - }, - // WRONG: Timeout error - { - name: "timeout error", - err: context.DeadlineExceeded, - url: "https://registry.npmjs.org/lodash", - context: "metadata", - wantStatus: http.StatusBadGateway, - wantContain: "Failed to fetch metadata", - }, - // EDGE: Empty context - { - name: "empty context", - err: errors.New("error"), - url: "https://example.com", - context: "", - wantStatus: http.StatusBadGateway, - wantContain: "Failed to fetch", - }, - // EDGE: Long URL - { - name: "long URL", - err: errors.New("error"), - url: "https://registry.npmjs.org/@scope/very-long-package-name/versions/1.2.3", - context: "package", - wantStatus: http.StatusBadGateway, - wantContain: "Failed to fetch package", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := httptest.NewRecorder() - HandleUpstreamError(w, tt.err, tt.url, tt.context) - - assert.Equal(t, tt.wantStatus, w.Code) - assert.Contains(t, w.Body.String(), tt.wantContain) - }) - } -} - -// TestCheckUpstreamStatus tests upstream status validation -func TestCheckUpstreamStatus(t *testing.T) { - tests := []struct { - name string - statusCode int - body io.ReadCloser - wantErr bool - errContains string - bodyClosed bool - }{ - // GOOD: OK status - { - name: "200 OK", - statusCode: http.StatusOK, - body: io.NopCloser(strings.NewReader("success")), - wantErr: false, - }, - // WRONG: Not found - { - name: "404 Not Found", - statusCode: http.StatusNotFound, - body: io.NopCloser(strings.NewReader("not found")), - wantErr: true, - errContains: "upstream returned status 404", - }, - // WRONG: Server error - { - name: "500 Internal Server Error", - statusCode: http.StatusInternalServerError, - body: io.NopCloser(strings.NewReader("error")), - wantErr: true, - errContains: "upstream returned status 500", - }, - // BAD: Unauthorized - { - name: "401 Unauthorized", - statusCode: http.StatusUnauthorized, - body: io.NopCloser(strings.NewReader("unauthorized")), - wantErr: true, - errContains: "upstream returned status 401", - }, - // EDGE: Nil body - { - name: "nil body with error", - statusCode: http.StatusNotFound, - body: nil, - wantErr: true, - errContains: "upstream returned status 404", - }, - // EDGE: Redirect status - { - name: "302 Found", - statusCode: http.StatusFound, - body: io.NopCloser(strings.NewReader("redirect")), - wantErr: true, - errContains: "upstream returned status 302", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := CheckUpstreamStatus(tt.statusCode, tt.body) - - if tt.wantErr { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errContains) - } else { - require.NoError(t, err) - } - }) - } -} - -// TestHandleInvalidRequest tests invalid request handling -func TestHandleInvalidRequest(t *testing.T) { - tests := []struct { - name string - registry string - wantStatus int - wantContain string - }{ - { - name: "npm invalid request", - registry: "npm", - wantStatus: http.StatusBadRequest, - wantContain: "Invalid npm request", - }, - { - name: "pypi invalid request", - registry: "pypi", - wantStatus: http.StatusBadRequest, - wantContain: "Invalid pypi request", - }, - { - name: "go invalid request", - registry: "go", - wantStatus: http.StatusBadRequest, - wantContain: "Invalid go request", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := httptest.NewRecorder() - HandleInvalidRequest(w, tt.registry) - - assert.Equal(t, tt.wantStatus, w.Code) - assert.Contains(t, w.Body.String(), tt.wantContain) - }) - } -} - -// TestHandleInternalError tests internal error handling -func TestHandleInternalError(t *testing.T) { - tests := []struct { - name string - err error - context string - wantStatus int - wantContain string - }{ - { - name: "database error", - err: errors.New("database connection failed"), - context: "database", - wantStatus: http.StatusInternalServerError, - wantContain: "Internal error: database", - }, - { - name: "cache error", - err: errors.New("cache write failed"), - context: "cache", - wantStatus: http.StatusInternalServerError, - wantContain: "Internal error: cache", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := httptest.NewRecorder() - HandleInternalError(w, tt.err, tt.context) - - assert.Equal(t, tt.wantStatus, w.Code) - assert.Contains(t, w.Body.String(), tt.wantContain) - }) - } -} - -// Note: FetchFromUpstream tests would require mocking cache.Manager and network.Client -// which requires concrete implementations. Integration tests cover this functionality. - -// TestWriteResponse tests HTTP response writing -func TestWriteResponse(t *testing.T) { - tests := []struct { - name string - data string - contentType string - wantStatus int - wantBody string - wantErr bool - }{ - // GOOD: Write tarball - { - name: "write tarball", - data: "package data here", - contentType: "application/octet-stream", - wantStatus: http.StatusOK, - wantBody: "package data here", - wantErr: false, - }, - // GOOD: Write JSON - { - name: "write JSON metadata", - data: `{"name":"react","version":"18.2.0"}`, - contentType: "application/json", - wantStatus: http.StatusOK, - wantBody: `{"name":"react","version":"18.2.0"}`, - wantErr: false, - }, - // EDGE: Empty data - { - name: "empty data", - data: "", - contentType: "text/plain", - wantStatus: http.StatusOK, - wantBody: "", - wantErr: false, - }, - // EDGE: Large data - { - name: "large data", - data: strings.Repeat("x", 100000), - contentType: "application/octet-stream", - wantStatus: http.StatusOK, - wantBody: strings.Repeat("x", 100000), - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := httptest.NewRecorder() - entry := &cache.CacheEntry{ - Data: io.NopCloser(bytes.NewReader([]byte(tt.data))), - } - - err := WriteResponse(w, entry, tt.contentType) - - if tt.wantErr { - require.Error(t, err) - } else { - require.NoError(t, err) - assert.Equal(t, tt.contentType, w.Header().Get("Content-Type")) - assert.Equal(t, tt.wantBody, w.Body.String()) - } - }) - } -} - -// TestBaseHandlerFields tests that BaseHandler fields are properly set -func TestBaseHandlerFields(t *testing.T) { - handler := NewBaseHandler(nil, nil, "npm", "https://registry.npmjs.org") - - tests := []struct { - name string - field string - expected interface{} - }{ - {"registry field", "registry", "npm"}, - {"upstream field", "upstream", "https://registry.npmjs.org"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - switch tt.field { - case "registry": - assert.Equal(t, tt.expected, handler.Registry) - case "upstream": - assert.Equal(t, tt.expected, handler.Upstream) - } - }) - } -} - -// TestProxyHandlerInterface tests that BaseHandler can be used as ProxyHandler -func TestProxyHandlerInterface(t *testing.T) { - handler := NewBaseHandler(nil, nil, "npm", "https://registry.npmjs.org") - - // Verify GetRegistry works - registry := handler.GetRegistry() - assert.Equal(t, "npm", registry) -} - -// TestConcurrentWriteResponse tests that WriteResponse is safe for concurrent use -func TestConcurrentWriteResponse(t *testing.T) { - const numGoroutines = 10 - - errs := make(chan error, numGoroutines) - for i := 0; i < numGoroutines; i++ { - go func(n int) { - w := httptest.NewRecorder() - data := strings.Repeat("x", 1000) - entry := &cache.CacheEntry{ - Data: io.NopCloser(bytes.NewReader([]byte(data))), - } - - err := WriteResponse(w, entry, "text/plain") - errs <- err - }(i) - } - - // Collect results - for i := 0; i < numGoroutines; i++ { - err := <-errs - assert.NoError(t, err) - } -} diff --git a/pkg/proxy/common/errors.go b/pkg/proxy/common/errors.go deleted file mode 100644 index 50f6b66..0000000 --- a/pkg/proxy/common/errors.go +++ /dev/null @@ -1,48 +0,0 @@ -package common - -import ( - "fmt" - "io" - "net/http" - - "github.com/rs/zerolog/log" -) - -// HandleUpstreamError logs an error and sends an HTTP 502 Bad Gateway response -// This is the common pattern used across all proxy handlers when upstream fetch fails -func HandleUpstreamError(w http.ResponseWriter, err error, url, context string) { - log.Error(). - Err(err). - Str("url", url). - Str("context", context). - Msg("Failed to fetch from upstream") - - http.Error(w, fmt.Sprintf("Failed to fetch %s", context), http.StatusBadGateway) -} - -// CheckUpstreamStatus validates HTTP status code from upstream -// Returns error if status is not OK, closing body if needed -func CheckUpstreamStatus(statusCode int, body io.ReadCloser) error { - if statusCode != http.StatusOK { - if body != nil { - body.Close() // #nosec G104 -- Cleanup, error not critical - } - return fmt.Errorf("upstream returned status %d", statusCode) - } - return nil -} - -// HandleInvalidRequest sends a 400 Bad Request response for invalid proxy requests -func HandleInvalidRequest(w http.ResponseWriter, registry string) { - http.Error(w, fmt.Sprintf("Invalid %s request", registry), http.StatusBadRequest) -} - -// HandleInternalError logs an internal error and sends 500 response -func HandleInternalError(w http.ResponseWriter, err error, context string) { - log.Error(). - Err(err). - Str("context", context). - Msg("Internal error processing request") - - http.Error(w, fmt.Sprintf("Internal error: %s", context), http.StatusInternalServerError) -} diff --git a/pkg/proxy/common/http.go b/pkg/proxy/common/http.go deleted file mode 100644 index 4d02349..0000000 --- a/pkg/proxy/common/http.go +++ /dev/null @@ -1,58 +0,0 @@ -package common - -import ( - "context" - "io" - "net/http" - - "github.com/lukaszraczylo/gohoarder/pkg/cache" - "github.com/lukaszraczylo/gohoarder/pkg/network" - "github.com/rs/zerolog/log" -) - -// FetchFromUpstream is a common helper to fetch content from upstream with caching -// This encapsulates the common pattern of: cache.Get -> network.Get -> error handling -func FetchFromUpstream( - ctx context.Context, - cacheManager *cache.Manager, - client *network.Client, - registry, name, version, upstreamURL string, -) (*cache.CacheEntry, error) { - entry, err := cacheManager.Get(ctx, registry, name, version, func(ctx context.Context) (io.ReadCloser, string, error) { - body, statusCode, err := client.Get(ctx, upstreamURL, nil) - if err != nil { - return nil, "", err - } - if err := CheckUpstreamStatus(statusCode, body); err != nil { - return nil, "", err - } - return body, upstreamURL, nil - }) - - if err != nil { - log.Error(). - Err(err). - Str("url", upstreamURL). - Str("registry", registry). - Str("name", name). - Str("version", version). - Msg("Failed to fetch package from upstream") - return nil, err - } - - return entry, nil -} - -// WriteResponse writes the cache entry data to the HTTP response writer -// Sets appropriate content type and handles errors -func WriteResponse(w http.ResponseWriter, entry *cache.CacheEntry, contentType string) error { - defer entry.Data.Close() // #nosec G104 -- Cleanup, error not critical - - w.Header().Set("Content-Type", contentType) - if _, err := io.Copy(w, entry.Data); err != nil { - log.Error().Err(err).Msg("Failed to write response") - return err - } - - return nil -} diff --git a/pkg/proxy/common/interface.go b/pkg/proxy/common/interface.go deleted file mode 100644 index eb08a39..0000000 --- a/pkg/proxy/common/interface.go +++ /dev/null @@ -1,29 +0,0 @@ -package common - -import ( - "context" - "net/http" - "time" -) - -// ProxyHandler defines the common interface for all registry proxies -type ProxyHandler interface { - http.Handler // ServeHTTP(w http.ResponseWriter, r *http.Request) - - // GetRegistry returns the registry type (npm, pypi, go) - GetRegistry() string - - // Health checks if the proxy can reach its upstream - Health(ctx context.Context) error -} - -// Stats represents proxy statistics -type Stats struct { - Registry string - TotalRequests int64 - CacheHits int64 - CacheMisses int64 - UpstreamErrors int64 - AvgResponseTime time.Duration - LastUpdated time.Time -} diff --git a/pkg/proxy/goproxy/goproxy.go b/pkg/proxy/goproxy/goproxy.go index 7d629c7..8426007 100644 --- a/pkg/proxy/goproxy/goproxy.go +++ b/pkg/proxy/goproxy/goproxy.go @@ -20,21 +20,21 @@ import ( type Handler struct { cache *cache.Manager client *network.Client - upstream string - sumDBURL string credExtractor *auth.CredentialExtractor credHasher *auth.CredentialHasher credValidator *auth.GoValidator validationCache *auth.ValidationCache gitFetcher *vcs.GitFetcher moduleBuilder *vcs.ModuleBuilder + upstream string + sumDBURL string } // Config holds Go proxy configuration type Config struct { - Upstream string // Upstream Go proxy (e.g., proxy.golang.org) - SumDBURL string // Checksum database URL - CredStore *vcs.CredentialStore // Optional credential store for git access + CredStore *vcs.CredentialStore + Upstream string + SumDBURL string } // New creates a new Go proxy handler diff --git a/pkg/proxy/npm/npm.go b/pkg/proxy/npm/npm.go index 18ef9af..a375a0d 100644 --- a/pkg/proxy/npm/npm.go +++ b/pkg/proxy/npm/npm.go @@ -21,11 +21,11 @@ import ( type Handler struct { cache *cache.Manager client *network.Client - upstream string credExtractor *auth.CredentialExtractor credHasher *auth.CredentialHasher credValidator *auth.NPMValidator validationCache *auth.ValidationCache + upstream string } // Config holds NPM proxy configuration diff --git a/pkg/proxy/pypi/pypi.go b/pkg/proxy/pypi/pypi.go index 1f45d4b..4a2d06b 100644 --- a/pkg/proxy/pypi/pypi.go +++ b/pkg/proxy/pypi/pypi.go @@ -21,11 +21,11 @@ import ( type Handler struct { cache *cache.Manager client *network.Client - upstream string credExtractor *auth.CredentialExtractor credHasher *auth.CredentialHasher credValidator *auth.PyPIValidator validationCache *auth.ValidationCache + upstream string } // Config holds PyPI proxy configuration diff --git a/pkg/scanner/ghsa/ghsa.go b/pkg/scanner/ghsa/ghsa.go index a8099c5..2d10054 100644 --- a/pkg/scanner/ghsa/ghsa.go +++ b/pkg/scanner/ghsa/ghsa.go @@ -20,8 +20,8 @@ const ScannerName = "github-advisory-database" // Scanner implements the GitHub Advisory Database vulnerability scanner type Scanner struct { - config config.GHSAConfig httpClient *http.Client + config config.GHSAConfig } // New creates a new GitHub Advisory Database scanner @@ -257,10 +257,10 @@ type GHSAAdvisory struct { Description string `json:"description"` Severity string `json:"severity"` HTMLURL string `json:"html_url"` - References []GHSAReference `json:"references"` - Vulnerabilities []GHSAVulnerability `json:"vulnerabilities"` PublishedAt string `json:"published_at"` UpdatedAt string `json:"updated_at"` + References []GHSAReference `json:"references"` + Vulnerabilities []GHSAVulnerability `json:"vulnerabilities"` } type GHSAReference struct { @@ -268,9 +268,9 @@ type GHSAReference struct { } type GHSAVulnerability struct { + FirstPatchedVersion *GHSAPatchVersion `json:"first_patched_version"` Package GHSAPackage `json:"package"` VulnerableVersions string `json:"vulnerable_version_range"` - FirstPatchedVersion *GHSAPatchVersion `json:"first_patched_version"` } type GHSAPackage struct { diff --git a/pkg/scanner/grype/grype.go b/pkg/scanner/grype/grype.go index aa6c41f..40b18db 100644 --- a/pkg/scanner/grype/grype.go +++ b/pkg/scanner/grype/grype.go @@ -153,9 +153,9 @@ func (s *Scanner) convertGrypeResult(grypeResult *GrypeResult, registry, package // GrypeResult represents Grype JSON output structure type GrypeResult struct { - Matches []GrypeMatch `json:"matches"` - Descriptor GrypeDescriptor `json:"descriptor"` Source GrypeSource `json:"source"` + Descriptor GrypeDescriptor `json:"descriptor"` + Matches []GrypeMatch `json:"matches"` } type GrypeDescriptor struct { @@ -164,13 +164,13 @@ type GrypeDescriptor struct { } type GrypeSource struct { - Type string `json:"type"` Target map[string]interface{} `json:"target"` + Type string `json:"type"` } type GrypeMatch struct { - Vulnerability GrypeVulnerability `json:"vulnerability"` Artifact GrypeArtifact `json:"artifact"` + Vulnerability GrypeVulnerability `json:"vulnerability"` } type GrypeVulnerability struct { diff --git a/pkg/scanner/npmaudit/npmaudit.go b/pkg/scanner/npmaudit/npmaudit.go index bec736a..f47862d 100644 --- a/pkg/scanner/npmaudit/npmaudit.go +++ b/pkg/scanner/npmaudit/npmaudit.go @@ -199,9 +199,9 @@ func (s *Scanner) convertResult(auditResult *NpmAuditResult, registry, packageNa // NpmAuditResult represents npm audit JSON output type NpmAuditResult struct { - AuditReportVersion int `json:"auditReportVersion"` Vulnerabilities map[string]NpmVulnerability `json:"vulnerabilities"` Metadata NpmAuditMetadata `json:"metadata"` + AuditReportVersion int `json:"auditReportVersion"` } type NpmVulnerability struct { diff --git a/pkg/scanner/osv/osv.go b/pkg/scanner/osv/osv.go index e1650ee..aeae309 100644 --- a/pkg/scanner/osv/osv.go +++ b/pkg/scanner/osv/osv.go @@ -25,8 +25,8 @@ const ( // Scanner implements the Scanner interface using OSV.dev API type Scanner struct { - config config.OSVConfig httpClient *http.Client + config config.OSVConfig } // OSVRequest represents the request structure for OSV API @@ -48,13 +48,13 @@ type OSVResponse struct { // OSVVulnerability represents a vulnerability in OSV format type OSVVulnerability struct { + DatabaseSpecific map[string]interface{} `json:"database_specific,omitempty"` ID string `json:"id"` Summary string `json:"summary"` Details string `json:"details"` Severity []OSVSeverity `json:"severity,omitempty"` References []OSVReference `json:"references,omitempty"` Affected []OSVAffected `json:"affected"` - DatabaseSpecific map[string]interface{} `json:"database_specific,omitempty"` } // OSVSeverity represents severity information @@ -71,11 +71,11 @@ type OSVReference struct { // OSVAffected represents affected package versions type OSVAffected struct { + DatabaseSpecific map[string]interface{} `json:"database_specific,omitempty"` + EcosystemSpecific map[string]interface{} `json:"ecosystem_specific,omitempty"` Package PackageInfo `json:"package"` Ranges []OSVRange `json:"ranges,omitempty"` Versions []string `json:"versions,omitempty"` - DatabaseSpecific map[string]interface{} `json:"database_specific,omitempty"` - EcosystemSpecific map[string]interface{} `json:"ecosystem_specific,omitempty"` } // OSVRange represents version ranges diff --git a/pkg/scanner/rescanner.go b/pkg/scanner/rescanner.go index 958696c..680472c 100644 --- a/pkg/scanner/rescanner.go +++ b/pkg/scanner/rescanner.go @@ -11,11 +11,11 @@ import ( // RescanWorker handles periodic re-scanning of cached packages type RescanWorker struct { - manager *Manager metadataStore metadata.MetadataStore storage storage.StorageBackend - interval time.Duration + manager *Manager stopCh chan struct{} + interval time.Duration } // NewRescanWorker creates a new rescan worker diff --git a/pkg/scanner/scanner.go b/pkg/scanner/scanner.go index 5afc9aa..9c384f7 100644 --- a/pkg/scanner/scanner.go +++ b/pkg/scanner/scanner.go @@ -36,10 +36,10 @@ type DatabaseUpdater interface { // Manager manages multiple security scanners type Manager struct { + metadataStore metadata.MetadataStore + config config.SecurityConfig scanners []Scanner enabled bool - config config.SecurityConfig - metadataStore metadata.MetadataStore } // New creates a new scanner manager with configured scanners diff --git a/pkg/scanner/trivy/trivy.go b/pkg/scanner/trivy/trivy.go index ae120f9..3ddb3d2 100644 --- a/pkg/scanner/trivy/trivy.go +++ b/pkg/scanner/trivy/trivy.go @@ -25,18 +25,18 @@ type Scanner struct { // TrivyResult represents Trivy JSON output structure type TrivyResult struct { - SchemaVersion int `json:"SchemaVersion"` + Metadata TrivyMetadata `json:"Metadata"` ArtifactName string `json:"ArtifactName"` ArtifactType string `json:"ArtifactType"` - Metadata TrivyMetadata `json:"Metadata"` Results []TrivyVulnResult `json:"Results"` + SchemaVersion int `json:"SchemaVersion"` } type TrivyMetadata struct { OS *TrivyOS `json:"OS,omitempty"` + ImageConfig *TrivyImageConfig `json:"ImageConfig,omitempty"` RepoTags []string `json:"RepoTags,omitempty"` RepoDigests []string `json:"RepoDigests,omitempty"` - ImageConfig *TrivyImageConfig `json:"ImageConfig,omitempty"` } type TrivyOS struct { @@ -64,8 +64,8 @@ type TrivyVulnerability struct { Severity string `json:"Severity"` Title string `json:"Title"` Description string `json:"Description"` - References []string `json:"References"` PrimaryURL string `json:"PrimaryURL"` + References []string `json:"References"` } // New creates a new Trivy scanner diff --git a/pkg/server/server.go b/pkg/server/server.go deleted file mode 100644 index 48bc4e7..0000000 --- a/pkg/server/server.go +++ /dev/null @@ -1,130 +0,0 @@ -package server - -import ( - "fmt" - "net/http" - - "github.com/lukaszraczylo/gohoarder/pkg/config" - "github.com/lukaszraczylo/gohoarder/pkg/health" - "github.com/lukaszraczylo/gohoarder/pkg/logger" - "github.com/lukaszraczylo/gohoarder/pkg/metrics" -) - -// Server wraps http.Server with configuration -type Server struct { - *http.Server - config *config.Config - healthChecker *health.Checker -} - -// New creates a new HTTP server -func New(cfg *config.Config, healthChecker *health.Checker) (*Server, error) { - mux := http.NewServeMux() - - // Register routes - registerRoutes(mux, cfg, healthChecker) - - srv := &http.Server{ - Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port), - Handler: logger.Middleware(mux), - ReadTimeout: cfg.Server.ReadTimeout, - WriteTimeout: cfg.Server.WriteTimeout, - IdleTimeout: cfg.Server.IdleTimeout, - } - - return &Server{ - Server: srv, - config: cfg, - healthChecker: healthChecker, - }, nil -} - -// registerRoutes registers all HTTP routes -func registerRoutes(mux *http.ServeMux, cfg *config.Config, healthChecker *health.Checker) { - // Health endpoints - mux.HandleFunc("/health", healthChecker.HealthHandler()) - mux.HandleFunc("/health/ready", healthChecker.ReadyHandler()) - - // Metrics endpoint - mux.Handle("/metrics", metrics.Handler()) - - // API endpoints - mux.HandleFunc("/api/v1/info", handleInfo(cfg)) - - // Package manager proxy endpoints (placeholders for now) - if cfg.Handlers.Go.Enabled { - mux.HandleFunc("/go/", handleGoProxy()) - } - if cfg.Handlers.NPM.Enabled { - mux.HandleFunc("/npm/", handleNPMProxy()) - } - if cfg.Handlers.PyPI.Enabled { - mux.HandleFunc("/pypi/", handlePyPIProxy()) - } - - // Root endpoint - mux.HandleFunc("/", handleRoot()) -} - -// handleInfo returns server information -func handleInfo(cfg *config.Config) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - info := map[string]interface{}{ - "name": "GoHoarder", - "version": "dev", - "handlers": map[string]bool{ - "go": cfg.Handlers.Go.Enabled, - "npm": cfg.Handlers.NPM.Enabled, - "pypi": cfg.Handlers.PyPI.Enabled, - }, - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, `{"success":true,"data":%v}`, toJSON(info)) - } -} - -// handleGoProxy handles Go module proxy requests -func handleGoProxy() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - // TODO: Implement Go proxy handler - http.Error(w, `{"success":false,"error":{"code":"NOT_IMPLEMENTED","message":"Go proxy not yet implemented"}}`, http.StatusNotImplemented) - } -} - -// handleNPMProxy handles NPM registry requests -func handleNPMProxy() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - // TODO: Implement NPM proxy handler - http.Error(w, `{"success":false,"error":{"code":"NOT_IMPLEMENTED","message":"NPM proxy not yet implemented"}}`, http.StatusNotImplemented) - } -} - -// handlePyPIProxy handles PyPI requests -func handlePyPIProxy() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - // TODO: Implement PyPI proxy handler - http.Error(w, `{"success":false,"error":{"code":"NOT_IMPLEMENTED","message":"PyPI proxy not yet implemented"}}`, http.StatusNotImplemented) - } -} - -// handleRoot handles root path -func handleRoot() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/" { - http.NotFound(w, r) - return - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - fmt.Fprint(w, `{"success":true,"data":{"message":"GoHoarder - Universal Package Cache Proxy","docs":"https://github.com/lukaszraczylo/gohoarder"}}`) - } -} - -// toJSON is a simple JSON encoder (replace with proper implementation) -func toJSON(v interface{}) string { - // Simplified for now - proper implementation would use goccy/go-json - return fmt.Sprintf("%v", v) -} diff --git a/pkg/storage/filesystem/filesystem_test.go b/pkg/storage/filesystem/filesystem_test.go index 0988cb8..4813de5 100644 --- a/pkg/storage/filesystem/filesystem_test.go +++ b/pkg/storage/filesystem/filesystem_test.go @@ -17,8 +17,8 @@ import ( type FilesystemStorageTestSuite struct { suite.Suite - tempDir string fs *FilesystemStorage + tempDir string } func (s *FilesystemStorageTestSuite) SetupTest() { @@ -46,12 +46,12 @@ func TestFilesystemStorageTestSuite(t *testing.T) { // Test Put operation func (s *FilesystemStorageTestSuite) TestPut() { tests := []struct { + opts *storage.PutOptions + errorCheck func(error) bool name string key string data string - opts *storage.PutOptions expectError bool - errorCheck func(error) bool }{ { name: "successful put", @@ -122,8 +122,8 @@ func (s *FilesystemStorageTestSuite) TestGet() { tests := []struct { name string key string - expectError bool expectData string + expectError bool }{ { name: "get existing file", @@ -258,11 +258,11 @@ func (s *FilesystemStorageTestSuite) TestList() { } tests := []struct { + opts *storage.ListOptions name string prefix string - opts *storage.ListOptions - expectedCount int expectedKeys []string + expectedCount int }{ { name: "list all npm packages", @@ -422,8 +422,8 @@ func (s *FilesystemStorageTestSuite) TestContextCancellation() { cancel() tests := []struct { - name string fn func() error + name string }{ { name: "Get with cancelled context", @@ -679,8 +679,8 @@ func (s *FilesystemStorageTestSuite) TestChecksumValidation() { correctMD5 := "7dd7323e8ce3e087972f93d3711ef62b" tests := []struct { - name string opts *storage.PutOptions + name string expectError bool }{ { diff --git a/pkg/storage/interface.go b/pkg/storage/interface.go index d5d0230..10a2b35 100644 --- a/pkg/storage/interface.go +++ b/pkg/storage/interface.go @@ -52,21 +52,21 @@ type ListOptions struct { // StorageObject represents a stored object type StorageObject struct { - Key string - Size int64 Modified time.Time + Key string ETag string + Size int64 } // StorageInfo contains detailed object information type StorageInfo struct { - Key string - Size int64 Modified time.Time - ETag string - ContentType string Metadata map[string]string Checksums *Checksums + Key string + ETag string + ContentType string + Size int64 } // Checksums contains file checksums diff --git a/pkg/storage/s3/s3.go b/pkg/storage/s3/s3.go index 3274a90..249b56a 100644 --- a/pkg/storage/s3/s3.go +++ b/pkg/storage/s3/s3.go @@ -3,14 +3,10 @@ package s3 import ( "bytes" "context" - "crypto/md5" // #nosec G501 -- MD5 used for S3 Content-MD5 header, not cryptographic security - "crypto/sha256" - "encoding/hex" stderrors "errors" "fmt" "io" "strings" - "sync" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" @@ -18,261 +14,210 @@ import ( "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/lukaszraczylo/gohoarder/pkg/errors" - "github.com/lukaszraczylo/gohoarder/pkg/metrics" "github.com/lukaszraczylo/gohoarder/pkg/storage" "github.com/rs/zerolog/log" ) -// S3Storage implements storage.StorageBackend for AWS S3 -type S3Storage struct { - client *s3.Client - bucket string - prefix string - quota int64 - mu sync.RWMutex - used int64 -} - -// Config holds S3 configuration +// Config holds S3 storage configuration type Config struct { - Bucket string Region string - Endpoint string // For S3-compatible services (MinIO, etc.) + Bucket string + Prefix string AccessKeyID string SecretAccessKey string - Prefix string // Optional prefix for all keys - Quota int64 // Quota in bytes (0 = unlimited) - ForcePathStyle bool // For S3-compatible services + Endpoint string // Optional: for S3-compatible services like MinIO + ForcePathStyle bool // Optional: for S3-compatible services + MaxSizeBytes int64 +} + +// S3Storage implements storage.StorageBackend using AWS S3 +type S3Storage struct { + client *s3.Client + bucket string + prefix string + maxSizeBytes int64 } // New creates a new S3 storage backend -func New(ctx context.Context, cfg Config) (*S3Storage, error) { +func New(cfg Config) (*S3Storage, error) { if cfg.Bucket == "" { - return nil, errors.New(errors.ErrCodeInvalidConfig, "S3 bucket is required") + return nil, fmt.Errorf("S3 bucket is required") } if cfg.Region == "" { - return nil, errors.New(errors.ErrCodeInvalidConfig, "S3 region is required") + cfg.Region = "us-east-1" // Default region } // Build AWS config - var awsCfg aws.Config + var awsConfig aws.Config var err error + // Build config options + configOpts := []func(*config.LoadOptions) error{ + config.WithRegion(cfg.Region), + } + + // Add credentials if provided if cfg.AccessKeyID != "" && cfg.SecretAccessKey != "" { - // Use static credentials - awsCfg, err = config.LoadDefaultConfig(ctx, - config.WithRegion(cfg.Region), - config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider( + configOpts = append(configOpts, config.WithCredentialsProvider( + credentials.NewStaticCredentialsProvider( cfg.AccessKeyID, cfg.SecretAccessKey, "", - )), - ) - } else { - // Use default credential chain - awsCfg, err = config.LoadDefaultConfig(ctx, - config.WithRegion(cfg.Region), - ) + ), + )) } + awsConfig, err = config.LoadDefaultConfig(context.Background(), configOpts...) if err != nil { - return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to load AWS config") + return nil, fmt.Errorf("failed to load AWS config: %w", err) } - // Create S3 client - var s3Options []func(*s3.Options) - - if cfg.Endpoint != "" { - s3Options = append(s3Options, func(o *s3.Options) { + // Create S3 client with service-specific options + client := s3.NewFromConfig(awsConfig, func(o *s3.Options) { + // Use custom endpoint if provided (for MinIO, S3-compatible services, etc.) + if cfg.Endpoint != "" { o.BaseEndpoint = aws.String(cfg.Endpoint) - o.UsePathStyle = cfg.ForcePathStyle - }) + } + if cfg.ForcePathStyle { + o.UsePathStyle = true + } + }) + + storage := &S3Storage{ + client: client, + bucket: cfg.Bucket, + prefix: strings.TrimSuffix(cfg.Prefix, "/"), + maxSizeBytes: cfg.MaxSizeBytes, } - client := s3.NewFromConfig(awsCfg, s3Options...) + log.Info(). + Str("bucket", cfg.Bucket). + Str("region", cfg.Region). + Str("prefix", cfg.Prefix). + Msg("S3 storage initialized") - s3Storage := &S3Storage{ - client: client, - bucket: cfg.Bucket, - prefix: strings.TrimSuffix(cfg.Prefix, "/"), - quota: cfg.Quota, - } - - // Calculate initial usage - if err := s3Storage.calculateUsage(ctx); err != nil { - log.Warn().Err(err).Msg("Failed to calculate initial S3 storage usage") - } - - return s3Storage, nil + return storage, nil } -// Get retrieves a file from S3 +// Get retrieves data from S3 func (s *S3Storage) Get(ctx context.Context, key string) (io.ReadCloser, error) { - s3Key := s.buildKey(key) + fullKey := s.buildKey(key) - input := &s3.GetObjectInput{ + log.Debug().Str("key", fullKey).Msg("Getting object from S3") + + result, err := s.client.GetObject(ctx, &s3.GetObjectInput{ Bucket: aws.String(s.bucket), - Key: aws.String(s3Key), - } + Key: aws.String(fullKey), + }) - result, err := s.client.GetObject(ctx, input) if err != nil { if isNotFoundError(err) { - metrics.RecordStorageOperation("s3", "get", "not_found") - return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key)) + return nil, errors.NotFound(fmt.Sprintf("S3 object not found: %s", key)) } - metrics.RecordStorageOperation("s3", "get", "error") return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get object from S3") } - metrics.RecordStorageOperation("s3", "get", "success") return result.Body, nil } -// Put stores a file in S3 +// Put stores data in S3 func (s *S3Storage) Put(ctx context.Context, key string, data io.Reader, opts *storage.PutOptions) error { - s3Key := s.buildKey(key) + fullKey := s.buildKey(key) - // Read data into buffer to calculate checksums and size - var buf bytes.Buffer - md5Hash := md5.New() // #nosec G401 -- MD5 used for S3 integrity check, not cryptographic security - sha256Hash := sha256.New() - multiWriter := io.MultiWriter(&buf, md5Hash, sha256Hash) - - written, err := io.Copy(multiWriter, data) + // Read data into buffer to get size + buf := new(bytes.Buffer) + size, err := io.Copy(buf, data) if err != nil { - metrics.RecordStorageOperation("s3", "put", "error") - return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to read data") + return fmt.Errorf("failed to read data: %w", err) } - // Check quota before upload - if s.quota > 0 { - s.mu.RLock() - used := s.used - s.mu.RUnlock() + log.Debug(). + Str("key", fullKey). + Int64("size", size). + Msg("Putting object to S3") - if used+written > s.quota { - metrics.RecordStorageOperation("s3", "put", "quota_exceeded") - return errors.QuotaExceeded(s.quota) + // Check quota if set + if s.maxSizeBytes > 0 { + currentUsage, err := s.calculateUsage(ctx) + if err != nil { + log.Warn().Err(err).Msg("Failed to calculate current usage, skipping quota check") + } else if currentUsage+size > s.maxSizeBytes { + return errors.QuotaExceeded(s.maxSizeBytes) } } - // Verify checksums if provided - if opts != nil { - md5Sum := hex.EncodeToString(md5Hash.Sum(nil)) - sha256Sum := hex.EncodeToString(sha256Hash.Sum(nil)) - - if opts.ChecksumMD5 != "" && opts.ChecksumMD5 != md5Sum { - metrics.RecordStorageOperation("s3", "put", "checksum_error") - return errors.New(errors.ErrCodeChecksumMismatch, "MD5 checksum mismatch") - } - - if opts.ChecksumSHA256 != "" && opts.ChecksumSHA256 != sha256Sum { - metrics.RecordStorageOperation("s3", "put", "checksum_error") - return errors.New(errors.ErrCodeChecksumMismatch, "SHA256 checksum mismatch") - } - } - - // Prepare metadata - metadata := make(map[string]string) + // Convert metadata to S3 metadata format + s3Metadata := make(map[string]string) if opts != nil && opts.Metadata != nil { - metadata = opts.Metadata - } - - // Build put input - input := &s3.PutObjectInput{ - Bucket: aws.String(s.bucket), - Key: aws.String(s3Key), - Body: bytes.NewReader(buf.Bytes()), - Metadata: metadata, - } - - if opts != nil && opts.ContentType != "" { - input.ContentType = aws.String(opts.ContentType) + for k, v := range opts.Metadata { + s3Metadata[k] = v + } } // Upload to S3 - _, err = s.client.PutObject(ctx, input) + _, err = s.client.PutObject(ctx, &s3.PutObjectInput{ + Bucket: aws.String(s.bucket), + Key: aws.String(fullKey), + Body: bytes.NewReader(buf.Bytes()), + Metadata: s3Metadata, + }) + if err != nil { - metrics.RecordStorageOperation("s3", "put", "error") - return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to upload to S3") + return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to put object to S3") } - // Update usage - s.mu.Lock() - s.used += written - currentUsed := s.used - s.mu.Unlock() - - metrics.RecordStorageOperation("s3", "put", "success") - metrics.UpdateCacheSize("s3", currentUsed) return nil } -// Delete removes a file from S3 +// Delete removes data from S3 func (s *S3Storage) Delete(ctx context.Context, key string) error { - s3Key := s.buildKey(key) + fullKey := s.buildKey(key) - // Get size before deletion for quota tracking - statInfo, err := s.Stat(ctx, key) - if err != nil { - return err - } + log.Debug().Str("key", fullKey).Msg("Deleting object from S3") - input := &s3.DeleteObjectInput{ + _, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{ Bucket: aws.String(s.bucket), - Key: aws.String(s3Key), - } + Key: aws.String(fullKey), + }) - _, err = s.client.DeleteObject(ctx, input) if err != nil { - metrics.RecordStorageOperation("s3", "delete", "error") - return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to delete from S3") + return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to delete object from S3") } - // Update usage - s.mu.Lock() - s.used -= statInfo.Size - currentUsed := s.used - s.mu.Unlock() - - metrics.RecordStorageOperation("s3", "delete", "success") - metrics.UpdateCacheSize("s3", currentUsed) return nil } -// Exists checks if a file exists in S3 +// Exists checks if data exists in S3 func (s *S3Storage) Exists(ctx context.Context, key string) (bool, error) { - s3Key := s.buildKey(key) + fullKey := s.buildKey(key) - input := &s3.HeadObjectInput{ + _, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{ Bucket: aws.String(s.bucket), - Key: aws.String(s3Key), - } + Key: aws.String(fullKey), + }) - _, err := s.client.HeadObject(ctx, input) if err != nil { if isNotFoundError(err) { return false, nil } - return false, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to check existence in S3") + return false, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to check object existence in S3") } return true, nil } -// List lists files with prefix in S3 +// List returns a list of objects with the given prefix func (s *S3Storage) List(ctx context.Context, prefix string, opts *storage.ListOptions) ([]storage.StorageObject, error) { - s3Prefix := s.buildKey(prefix) + fullPrefix := s.buildKey(prefix) - input := &s3.ListObjectsV2Input{ - Bucket: aws.String(s.bucket), - Prefix: aws.String(s3Prefix), - } + log.Debug().Str("prefix", fullPrefix).Msg("Listing objects in S3") var objects []storage.StorageObject - paginator := s3.NewListObjectsV2Paginator(s.client, input) + paginator := s3.NewListObjectsV2Paginator(s.client, &s3.ListObjectsV2Input{ + Bucket: aws.String(s.bucket), + Prefix: aws.String(fullPrefix), + }) for paginator.HasMorePages() { page, err := paginator.NextPage(ctx) @@ -281,56 +226,58 @@ func (s *S3Storage) List(ctx context.Context, prefix string, opts *storage.ListO } for _, obj := range page.Contents { - key := s.stripPrefix(*obj.Key) - objects = append(objects, storage.StorageObject{ - Key: key, - Size: *obj.Size, - Modified: *obj.LastModified, - ETag: strings.Trim(*obj.ETag, "\""), - }) - } - } + if obj.Key != nil { + // Strip prefix from key + key := s.stripPrefix(*obj.Key) - // Apply pagination if requested - if opts != nil { - start := opts.Offset - end := len(objects) - if opts.MaxResults > 0 && start+opts.MaxResults < end { - end = start + opts.MaxResults - } - if start < len(objects) { - objects = objects[start:end] - } else { - objects = []storage.StorageObject{} + object := storage.StorageObject{ + Key: key, + Size: aws.ToInt64(obj.Size), + } + + if obj.LastModified != nil { + object.Modified = *obj.LastModified + } + + if obj.ETag != nil { + object.ETag = *obj.ETag + } + + objects = append(objects, object) + } } } return objects, nil } -// Stat gets file metadata from S3 +// Stat returns metadata about stored data func (s *S3Storage) Stat(ctx context.Context, key string) (*storage.StorageInfo, error) { - s3Key := s.buildKey(key) + fullKey := s.buildKey(key) - input := &s3.HeadObjectInput{ + result, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{ Bucket: aws.String(s.bucket), - Key: aws.String(s3Key), - } + Key: aws.String(fullKey), + }) - result, err := s.client.HeadObject(ctx, input) if err != nil { if isNotFoundError(err) { - return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key)) + return nil, errors.NotFound(fmt.Sprintf("S3 object not found: %s", key)) } return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to stat object in S3") } info := &storage.StorageInfo{ - Key: key, - Size: *result.ContentLength, - Modified: *result.LastModified, - ETag: strings.Trim(*result.ETag, "\""), - Metadata: result.Metadata, + Key: key, + Size: aws.ToInt64(result.ContentLength), + } + + if result.LastModified != nil { + info.Modified = *result.LastModified + } + + if result.ETag != nil { + info.ETag = *result.ETag } if result.ContentType != nil { @@ -340,33 +287,27 @@ func (s *S3Storage) Stat(ctx context.Context, key string) (*storage.StorageInfo, return info, nil } -// GetQuota returns quota information +// GetQuota returns current usage and quota information func (s *S3Storage) GetQuota(ctx context.Context) (*storage.QuotaInfo, error) { - s.mu.RLock() - used := s.used - s.mu.RUnlock() - - available := s.quota - used - if available < 0 { - available = 0 + usage, err := s.calculateUsage(ctx) + if err != nil { + return nil, err } return &storage.QuotaInfo{ - Used: used, - Available: available, - Limit: s.quota, + Used: usage, + Limit: s.maxSizeBytes, }, nil } -// Health checks S3 health +// Health checks if the S3 backend is healthy func (s *S3Storage) Health(ctx context.Context) error { - // Try to list bucket to verify connectivity - input := &s3.ListObjectsV2Input{ + // Try to list objects (lightweight operation) + _, err := s.client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{ Bucket: aws.String(s.bucket), MaxKeys: aws.Int32(1), - } + }) - _, err := s.client.ListObjectsV2(ctx, input) if err != nil { return errors.Wrap(err, errors.ErrCodeStorageFailure, "S3 health check failed") } @@ -374,60 +315,51 @@ func (s *S3Storage) Health(ctx context.Context) error { return nil } -// Close closes the storage backend +// Close closes the S3 storage backend func (s *S3Storage) Close() error { - // No cleanup needed for S3 client + log.Info().Msg("S3 storage closed") return nil } -// buildKey builds the full S3 key with prefix +// buildKey constructs the full S3 key with prefix func (s *S3Storage) buildKey(key string) string { - key = strings.TrimPrefix(key, "/") - if s.prefix != "" { - return s.prefix + "/" + key + if s.prefix == "" { + return key } - return key + return s.prefix + "/" + key } -// stripPrefix removes the configured prefix from an S3 key -func (s *S3Storage) stripPrefix(s3Key string) string { - if s.prefix != "" { - return strings.TrimPrefix(s3Key, s.prefix+"/") +// stripPrefix removes the prefix from an S3 key +func (s *S3Storage) stripPrefix(key string) string { + if s.prefix == "" { + return key } - return s3Key + return strings.TrimPrefix(key, s.prefix+"/") } -// calculateUsage calculates current S3 storage usage -func (s *S3Storage) calculateUsage(ctx context.Context) error { - var total int64 +// calculateUsage calculates total storage usage +func (s *S3Storage) calculateUsage(ctx context.Context) (int64, error) { + var totalSize int64 - input := &s3.ListObjectsV2Input{ + paginator := s3.NewListObjectsV2Paginator(s.client, &s3.ListObjectsV2Input{ Bucket: aws.String(s.bucket), - } - - if s.prefix != "" { - input.Prefix = aws.String(s.prefix + "/") - } - - paginator := s3.NewListObjectsV2Paginator(s.client, input) + Prefix: aws.String(s.prefix), + }) for paginator.HasMorePages() { page, err := paginator.NextPage(ctx) if err != nil { - return err + return 0, fmt.Errorf("failed to calculate usage: %w", err) } for _, obj := range page.Contents { - total += *obj.Size + if obj.Size != nil { + totalSize += aws.ToInt64(obj.Size) + } } } - s.mu.Lock() - s.used = total - s.mu.Unlock() - - metrics.UpdateCacheSize("s3", total) - return nil + return totalSize, nil } // isNotFoundError checks if an error is a "not found" error @@ -436,8 +368,21 @@ func isNotFoundError(err error) bool { return false } + // Check for specific S3 error types var notFound *types.NotFound var noSuchKey *types.NoSuchKey - return stderrors.As(err, ¬Found) || stderrors.As(err, &noSuchKey) + // Use errors.As to check for wrapped errors + if ok := stderrors.As(err, ¬Found); ok { + return true + } + if ok := stderrors.As(err, &noSuchKey); ok { + return true + } + + // Check error message as fallback + errMsg := err.Error() + return strings.Contains(errMsg, "NoSuchKey") || + strings.Contains(errMsg, "NotFound") || + strings.Contains(errMsg, "404") } diff --git a/pkg/storage/s3/s3_test.go b/pkg/storage/s3/s3_test.go new file mode 100644 index 0000000..3450d61 --- /dev/null +++ b/pkg/storage/s3/s3_test.go @@ -0,0 +1,271 @@ +package s3 + +import ( + "testing" + + "github.com/stretchr/testify/suite" +) + +type S3StorageTestSuite struct { + suite.Suite +} + +func TestS3StorageTestSuite(t *testing.T) { + suite.Run(t, new(S3StorageTestSuite)) +} + +func (s *S3StorageTestSuite) TestNewS3Storage() { + tests := []struct { + name string + config Config + expectError bool + errorMsg string + }{ + { + name: "valid config with credentials", + config: Config{ + Region: "us-east-1", + Bucket: "test-bucket", + Prefix: "packages/", + AccessKeyID: "AKIAIOSFODNN7EXAMPLE", + SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + MaxSizeBytes: 1024 * 1024, + }, + expectError: false, + }, + { + name: "valid config with custom endpoint", + config: Config{ + Region: "us-east-1", + Bucket: "test-bucket", + Endpoint: "https://minio.example.com", + AccessKeyID: "minioadmin", + SecretAccessKey: "minioadmin", + ForcePathStyle: true, + }, + expectError: false, + }, + { + name: "valid config with default region", + config: Config{ + Bucket: "test-bucket", + AccessKeyID: "test", + SecretAccessKey: "test", + }, + expectError: false, + }, + { + name: "missing bucket", + config: Config{ + Region: "us-east-1", + AccessKeyID: "test", + SecretAccessKey: "test", + }, + expectError: true, + errorMsg: "bucket is required", + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + storage, err := New(tt.config) + + if tt.expectError { + s.Error(err) + if tt.errorMsg != "" { + s.Contains(err.Error(), tt.errorMsg) + } + s.Nil(storage) + } else { + s.NoError(err) + s.NotNil(storage) + s.Equal(tt.config.Bucket, storage.bucket) + s.Equal(tt.config.MaxSizeBytes, storage.maxSizeBytes) + + // Test prefix normalization + if tt.config.Prefix != "" { + s.NotContains(storage.prefix, "/", "prefix should not end with /") + } + } + }) + } +} + +func (s *S3StorageTestSuite) TestBuildKey() { + tests := []struct { + name string + prefix string + key string + expected string + }{ + { + name: "with prefix", + prefix: "packages", + key: "test/file.txt", + expected: "packages/test/file.txt", + }, + { + name: "without prefix", + prefix: "", + key: "test/file.txt", + expected: "test/file.txt", + }, + { + name: "with trailing slash in prefix", + prefix: "packages/", + key: "test/file.txt", + expected: "packages/test/file.txt", + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + storage := &S3Storage{ + prefix: tt.prefix, + } + // Normalize prefix like in New() + if storage.prefix != "" && storage.prefix[len(storage.prefix)-1] == '/' { + storage.prefix = storage.prefix[:len(storage.prefix)-1] + } + + result := storage.buildKey(tt.key) + s.Equal(tt.expected, result) + }) + } +} + +func (s *S3StorageTestSuite) TestStripPrefix() { + tests := []struct { + name string + prefix string + key string + expected string + }{ + { + name: "with prefix", + prefix: "packages", + key: "packages/test/file.txt", + expected: "test/file.txt", + }, + { + name: "without prefix", + prefix: "", + key: "test/file.txt", + expected: "test/file.txt", + }, + { + name: "key without prefix but prefix set", + prefix: "packages", + key: "test/file.txt", + expected: "test/file.txt", + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + storage := &S3Storage{ + prefix: tt.prefix, + } + + result := storage.stripPrefix(tt.key) + s.Equal(tt.expected, result) + }) + } +} + +func (s *S3StorageTestSuite) TestIsNotFoundError() { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "nil error", + err: nil, + expected: false, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + result := isNotFoundError(tt.err) + s.Equal(tt.expected, result) + }) + } +} + +func (s *S3StorageTestSuite) TestConfigDefaults() { + config := Config{ + Bucket: "test-bucket", + AccessKeyID: "test", + SecretAccessKey: "test", + } + + storage, err := New(config) + s.Require().NoError(err) + s.NotNil(storage) + + // Verify defaults + s.Equal("test-bucket", storage.bucket) + s.Equal("", storage.prefix) + s.Equal(int64(0), storage.maxSizeBytes) +} + +func (s *S3StorageTestSuite) TestPrefixNormalization() { + tests := []struct { + name string + inputPrefix string + expectedPrefix string + }{ + { + name: "prefix with trailing slash", + inputPrefix: "packages/", + expectedPrefix: "packages", + }, + { + name: "prefix without trailing slash", + inputPrefix: "packages", + expectedPrefix: "packages", + }, + { + name: "empty prefix", + inputPrefix: "", + expectedPrefix: "", + }, + { + name: "nested prefix with trailing slash", + inputPrefix: "cache/packages/", + expectedPrefix: "cache/packages", + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + config := Config{ + Bucket: "test-bucket", + Prefix: tt.inputPrefix, + AccessKeyID: "test", + SecretAccessKey: "test", + } + + storage, err := New(config) + s.Require().NoError(err) + s.Equal(tt.expectedPrefix, storage.prefix) + }) + } +} + +func (s *S3StorageTestSuite) TestClose() { + config := Config{ + Bucket: "test-bucket", + AccessKeyID: "test", + SecretAccessKey: "test", + } + + storage, err := New(config) + s.Require().NoError(err) + + // Close should not error + err = storage.Close() + s.NoError(err) +} diff --git a/pkg/storage/smb/smb.go b/pkg/storage/smb/smb.go index d8d4d9a..7da7a45 100644 --- a/pkg/storage/smb/smb.go +++ b/pkg/storage/smb/smb.go @@ -1,42 +1,43 @@ package smb import ( - "bytes" "context" - "crypto/md5" // #nosec G501 -- MD5 used for file checksums, not cryptographic security - "crypto/sha256" - "encoding/hex" "fmt" "io" "net" "os" "path/filepath" "strings" - "sync" "time" "github.com/hirochachacha/go-smb2" "github.com/lukaszraczylo/gohoarder/pkg/errors" - "github.com/lukaszraczylo/gohoarder/pkg/metrics" "github.com/lukaszraczylo/gohoarder/pkg/storage" "github.com/rs/zerolog/log" ) -// SMBStorage implements storage.StorageBackend for SMB/CIFS shares -type SMBStorage struct { - host string - share string - basePath string - username string - password string - quota int64 - mu sync.RWMutex - used int64 - connPool chan *smbConnection - poolSize int +// Config holds SMB storage configuration +type Config struct { + Host string + Share string + Path string + Username string + Password string + Domain string + Port int + MaxSizeBytes int64 + PoolSize int } -// smbConnection wraps an SMB session and share +// SMBStorage implements storage.StorageBackend using SMB/CIFS +type SMBStorage struct { + connPool chan *smbConnection + config Config + maxSizeBytes int64 + poolSize int +} + +// smbConnection represents a pooled SMB connection type smbConnection struct { conn net.Conn session *smb2.Session @@ -44,27 +45,14 @@ type smbConnection struct { lastUse time.Time } -// Config holds SMB configuration -type Config struct { - Host string // SMB server hostname or IP - Port int // SMB server port (default: 445) - Share string // SMB share name - BasePath string // Base path within the share - Username string // SMB username - Password string // SMB password - Domain string // SMB domain (optional) - Quota int64 // Quota in bytes (0 = unlimited) - PoolSize int // Connection pool size (default: 5) -} - // New creates a new SMB storage backend -func New(ctx context.Context, cfg Config) (*SMBStorage, error) { +func New(cfg Config) (*SMBStorage, error) { if cfg.Host == "" { - return nil, errors.New(errors.ErrCodeInvalidConfig, "SMB host is required") + return nil, fmt.Errorf("SMB host is required") } if cfg.Share == "" { - return nil, errors.New(errors.ErrCodeInvalidConfig, "SMB share is required") + return nil, fmt.Errorf("SMB share is required") } if cfg.Port == 0 { @@ -75,64 +63,68 @@ func New(ctx context.Context, cfg Config) (*SMBStorage, error) { cfg.PoolSize = 5 // Default pool size } - smbStorage := &SMBStorage{ - host: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), - share: cfg.Share, - basePath: strings.Trim(cfg.BasePath, "/\\"), - username: cfg.Username, - password: cfg.Password, - quota: cfg.Quota, - connPool: make(chan *smbConnection, cfg.PoolSize), - poolSize: cfg.PoolSize, + // Normalize path + cfg.Path = strings.Trim(cfg.Path, "/\\") + + storage := &SMBStorage{ + config: cfg, + maxSizeBytes: cfg.MaxSizeBytes, + poolSize: cfg.PoolSize, + connPool: make(chan *smbConnection, cfg.PoolSize), } - // Initialize connection pool + // Pre-populate connection pool for i := 0; i < cfg.PoolSize; i++ { - conn, err := smbStorage.createConnection(ctx) + conn, err := storage.createConnection() if err != nil { - // Clean up any created connections - close(smbStorage.connPool) - for c := range smbStorage.connPool { - c.close() - } - return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create SMB connection pool") + log.Warn().Err(err).Int("attempt", i).Msg("Failed to create initial SMB connection") + continue } - smbStorage.connPool <- conn + storage.connPool <- conn } - // Calculate initial usage - if err := smbStorage.calculateUsage(ctx); err != nil { - log.Warn().Err(err).Msg("Failed to calculate initial SMB storage usage") - } + log.Info(). + Str("host", cfg.Host). + Int("port", cfg.Port). + Str("share", cfg.Share). + Str("path", cfg.Path). + Int("pool_size", cfg.PoolSize). + Msg("SMB storage initialized") - return smbStorage, nil + return storage, nil } // createConnection creates a new SMB connection -func (s *SMBStorage) createConnection(ctx context.Context) (*smbConnection, error) { - conn, err := net.Dial("tcp", s.host) +func (s *SMBStorage) createConnection() (*smbConnection, error) { + // Connect to SMB server (use net.JoinHostPort for IPv6 compatibility) + addr := net.JoinHostPort(s.config.Host, fmt.Sprintf("%d", s.config.Port)) + conn, err := net.DialTimeout("tcp", addr, 10*time.Second) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to connect to SMB server: %w", err) } - dialer := &smb2.Dialer{ + // Create SMB dialer + d := &smb2.Dialer{ Initiator: &smb2.NTLMInitiator{ - User: s.username, - Password: s.password, + User: s.config.Username, + Password: s.config.Password, + Domain: s.config.Domain, }, } - session, err := dialer.Dial(conn) + // Establish SMB session + session, err := d.Dial(conn) if err != nil { - conn.Close() // #nosec G104 -- Cleanup, error not critical - return nil, err + _ = conn.Close() + return nil, fmt.Errorf("failed to establish SMB session: %w", err) } - share, err := session.Mount(s.share) + // Mount share + share, err := session.Mount(s.config.Share) if err != nil { - _ = session.Logoff() // #nosec G104 -- SMB cleanup - conn.Close() // #nosec G104 -- Cleanup, error not critical - return nil, err + _ = session.Logoff() + _ = conn.Close() + return nil, fmt.Errorf("failed to mount SMB share: %w", err) } return &smbConnection{ @@ -143,25 +135,34 @@ func (s *SMBStorage) createConnection(ctx context.Context) (*smbConnection, erro }, nil } -// getConnection gets a connection from the pool -func (s *SMBStorage) getConnection(ctx context.Context) (*smbConnection, error) { +// getConnection gets a connection from the pool or creates a new one +func (s *SMBStorage) getConnection() (*smbConnection, error) { select { case conn := <-s.connPool: + // Check if connection is still valid (not older than 5 minutes idle) + if time.Since(conn.lastUse) > 5*time.Minute { + conn.close() + return s.createConnection() + } conn.lastUse = time.Now() return conn, nil - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(30 * time.Second): - return nil, errors.New(errors.ErrCodeStorageFailure, "timeout waiting for SMB connection") + default: + // Pool is empty, create new connection + return s.createConnection() } } // returnConnection returns a connection to the pool func (s *SMBStorage) returnConnection(conn *smbConnection) { + if conn == nil { + return + } + select { case s.connPool <- conn: + // Successfully returned to pool default: - // Pool is full, close the connection + // Pool is full, close connection conn.close() } } @@ -169,189 +170,161 @@ func (s *SMBStorage) returnConnection(conn *smbConnection) { // close closes an SMB connection func (c *smbConnection) close() { if c.share != nil { - _ = c.share.Umount() // #nosec G104 -- SMB cleanup + if err := c.share.Umount(); err != nil { + log.Warn().Err(err).Msg("Failed to unmount SMB share") + } } if c.session != nil { - _ = c.session.Logoff() // #nosec G104 -- SMB cleanup + if err := c.session.Logoff(); err != nil { + log.Warn().Err(err).Msg("Failed to logoff SMB session") + } } if c.conn != nil { - c.conn.Close() // #nosec G104 -- Cleanup, error not critical + if err := c.conn.Close(); err != nil { + log.Warn().Err(err).Msg("Failed to close SMB connection") + } } } -// Get retrieves a file from SMB share +// Get retrieves data from SMB share func (s *SMBStorage) Get(ctx context.Context, key string) (io.ReadCloser, error) { - conn, err := s.getConnection(ctx) + conn, err := s.getConnection() if err != nil { - return nil, err + return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get SMB connection") } + defer s.returnConnection(conn) path := s.keyToPath(key) + log.Debug().Str("key", path).Msg("Getting file from SMB") + // Open file file, err := conn.share.Open(path) if err != nil { - s.returnConnection(conn) if os.IsNotExist(err) { - metrics.RecordStorageOperation("smb", "get", "not_found") - return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key)) + return nil, errors.NotFound(fmt.Sprintf("SMB file not found: %s", key)) } - metrics.RecordStorageOperation("smb", "get", "error") return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to open SMB file") } - // Read entire file into memory and close SMB connection - // This is necessary because we need to return the connection to the pool + // Read entire file into memory (SMB files must be read completely before closing connection) data, err := io.ReadAll(file) - file.Close() // #nosec G104 -- Cleanup, error not critical - s.returnConnection(conn) - + if closeErr := file.Close(); closeErr != nil { + log.Warn().Err(closeErr).Str("path", path).Msg("Failed to close SMB file after reading") + } if err != nil { - metrics.RecordStorageOperation("smb", "get", "error") return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to read SMB file") } - metrics.RecordStorageOperation("smb", "get", "success") - return io.NopCloser(bytes.NewReader(data)), nil + // Return as ReadCloser + return io.NopCloser(strings.NewReader(string(data))), nil } -// Put stores a file on SMB share +// Put stores data on SMB share func (s *SMBStorage) Put(ctx context.Context, key string, data io.Reader, opts *storage.PutOptions) error { - conn, err := s.getConnection(ctx) + conn, err := s.getConnection() if err != nil { - return err + return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get SMB connection") } defer s.returnConnection(conn) path := s.keyToPath(key) - dir := filepath.Dir(path) - // Create directory structure - if err := conn.share.MkdirAll(dir, 0755); err != nil { - metrics.RecordStorageOperation("smb", "put", "error") + log.Debug().Str("key", path).Msg("Putting file to SMB") + + // Ensure directory exists + dir := filepath.Dir(path) + if err := s.ensureDir(conn, dir); err != nil { return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create SMB directory") } - // Read data into buffer to calculate checksums and size - var buf bytes.Buffer - md5Hash := md5.New() // #nosec G401 -- MD5 used for file integrity check, not cryptographic security - sha256Hash := sha256.New() - multiWriter := io.MultiWriter(&buf, md5Hash, sha256Hash) - - written, err := io.Copy(multiWriter, data) + // Read data into buffer to check quota + buf := new(strings.Builder) + size, err := io.Copy(buf, data) if err != nil { - metrics.RecordStorageOperation("smb", "put", "error") return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to read data") } - // Check quota - if s.quota > 0 { - s.mu.RLock() - used := s.used - s.mu.RUnlock() - - if used+written > s.quota { - metrics.RecordStorageOperation("smb", "put", "quota_exceeded") - return errors.QuotaExceeded(s.quota) + // Check quota if set + if s.maxSizeBytes > 0 { + currentUsage, err := s.calculateUsage(conn) + if err != nil { + log.Warn().Err(err).Msg("Failed to calculate current usage, skipping quota check") + } else if currentUsage+size > s.maxSizeBytes { + return errors.QuotaExceeded(s.maxSizeBytes) } } - // Verify checksums if provided - if opts != nil { - md5Sum := hex.EncodeToString(md5Hash.Sum(nil)) - sha256Sum := hex.EncodeToString(sha256Hash.Sum(nil)) - - if opts.ChecksumMD5 != "" && opts.ChecksumMD5 != md5Sum { - metrics.RecordStorageOperation("smb", "put", "checksum_error") - return errors.New(errors.ErrCodeChecksumMismatch, "MD5 checksum mismatch") - } - - if opts.ChecksumSHA256 != "" && opts.ChecksumSHA256 != sha256Sum { - metrics.RecordStorageOperation("smb", "put", "checksum_error") - return errors.New(errors.ErrCodeChecksumMismatch, "SHA256 checksum mismatch") - } - } - - // Create temp file for atomic write - tempPath := path + ".tmp" - file, err := conn.share.Create(tempPath) + // Create/overwrite file + file, err := conn.share.Create(path) if err != nil { - metrics.RecordStorageOperation("smb", "put", "error") - return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create SMB temp file") + return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create SMB file") } + defer file.Close() // Write data - _, err = io.Copy(file, bytes.NewReader(buf.Bytes())) - file.Close() // #nosec G104 -- Cleanup, error not critical - + _, err = file.Write([]byte(buf.String())) if err != nil { - _ = conn.share.Remove(tempPath) // #nosec G104 -- SMB cleanup - metrics.RecordStorageOperation("smb", "put", "error") return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to write SMB file") } - // Atomic rename - if err := conn.share.Rename(tempPath, path); err != nil { - _ = conn.share.Remove(tempPath) // #nosec G104 -- SMB cleanup - metrics.RecordStorageOperation("smb", "put", "error") - return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to rename SMB temp file") - } - - // Update usage - s.mu.Lock() - s.used += written - currentUsed := s.used - s.mu.Unlock() - - metrics.RecordStorageOperation("smb", "put", "success") - metrics.UpdateCacheSize("smb", currentUsed) return nil } -// Delete removes a file from SMB share -func (s *SMBStorage) Delete(ctx context.Context, key string) error { - conn, err := s.getConnection(ctx) - if err != nil { +// ensureDir ensures a directory exists on SMB share +func (s *SMBStorage) ensureDir(conn *smbConnection, path string) error { + if path == "" || path == "." || path == "/" { + return nil + } + + // Try to stat the directory + _, err := conn.share.Stat(path) + if err == nil { + return nil // Directory exists + } + + // Create parent directory first + parent := filepath.Dir(path) + if parent != path && parent != "." && parent != "/" { + if err := s.ensureDir(conn, parent); err != nil { + return err + } + } + + // Create this directory + err = conn.share.Mkdir(path, 0755) + if err != nil && !os.IsExist(err) { return err } + + return nil +} + +// Delete removes data from SMB share +func (s *SMBStorage) Delete(ctx context.Context, key string) error { + conn, err := s.getConnection() + if err != nil { + return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get SMB connection") + } defer s.returnConnection(conn) path := s.keyToPath(key) - // Get size before deletion - info, err := conn.share.Stat(path) - if err != nil { - if os.IsNotExist(err) { - metrics.RecordStorageOperation("smb", "delete", "not_found") - return errors.NotFound(fmt.Sprintf("file not found: %s", key)) - } - metrics.RecordStorageOperation("smb", "delete", "error") - return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to stat SMB file") - } + log.Debug().Str("key", path).Msg("Deleting file from SMB") - size := info.Size() - - if err := conn.share.Remove(path); err != nil { - metrics.RecordStorageOperation("smb", "delete", "error") + err = conn.share.Remove(path) + if err != nil && !os.IsNotExist(err) { return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to delete SMB file") } - // Update usage - s.mu.Lock() - s.used -= size - currentUsed := s.used - s.mu.Unlock() - - metrics.RecordStorageOperation("smb", "delete", "success") - metrics.UpdateCacheSize("smb", currentUsed) return nil } -// Exists checks if a file exists on SMB share +// Exists checks if data exists on SMB share func (s *SMBStorage) Exists(ctx context.Context, key string) (bool, error) { - conn, err := s.getConnection(ctx) + conn, err := s.getConnection() if err != nil { - return false, err + return false, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get SMB connection") } defer s.returnConnection(conn) @@ -368,57 +341,90 @@ func (s *SMBStorage) Exists(ctx context.Context, key string) (bool, error) { return true, nil } -// List lists files with prefix on SMB share +// List returns a list of objects with the given prefix func (s *SMBStorage) List(ctx context.Context, prefix string, opts *storage.ListOptions) ([]storage.StorageObject, error) { - conn, err := s.getConnection(ctx) + conn, err := s.getConnection() if err != nil { - return nil, err + return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get SMB connection") } defer s.returnConnection(conn) - searchPath := s.keyToPath(prefix) + basePath := s.keyToPath(prefix) + + log.Debug().Str("prefix", basePath).Msg("Listing files in SMB") + var objects []storage.StorageObject - err = s.walkPath(conn.share, searchPath, func(path string, info os.FileInfo) error { + // Walk the directory tree + err = s.walkPath(conn, basePath, func(path string, info os.FileInfo) error { if info.IsDir() { return nil } + // Convert path back to key key := s.pathToKey(path) + objects = append(objects, storage.StorageObject{ Key: key, Size: info.Size(), Modified: info.ModTime(), }) + return nil }) - if err != nil && !os.IsNotExist(err) { + if err != nil { return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to list SMB files") } - // Apply pagination if requested - if opts != nil { - start := opts.Offset - end := len(objects) - if opts.MaxResults > 0 && start+opts.MaxResults < end { - end = start + opts.MaxResults - } - if start < len(objects) { - objects = objects[start:end] - } else { - objects = []storage.StorageObject{} - } - } - return objects, nil } -// Stat gets file metadata from SMB share -func (s *SMBStorage) Stat(ctx context.Context, key string) (*storage.StorageInfo, error) { - conn, err := s.getConnection(ctx) +// walkPath walks a directory tree on SMB share +func (s *SMBStorage) walkPath(conn *smbConnection, root string, fn func(string, os.FileInfo) error) error { + // Check if root exists + info, err := conn.share.Stat(root) if err != nil { - return nil, err + if os.IsNotExist(err) { + return nil // Empty directory + } + return err + } + + // If root is a file, process it directly + if !info.IsDir() { + return fn(root, info) + } + + // List directory contents + entries, err := conn.share.ReadDir(root) + if err != nil { + return err + } + + for _, entry := range entries { + fullPath := filepath.Join(root, entry.Name()) + + if err := fn(fullPath, entry); err != nil { + return err + } + + // Recurse into subdirectories + if entry.IsDir() { + if err := s.walkPath(conn, fullPath, fn); err != nil { + return err + } + } + } + + return nil +} + +// Stat returns metadata about stored data +func (s *SMBStorage) Stat(ctx context.Context, key string) (*storage.StorageInfo, error) { + conn, err := s.getConnection() + if err != nil { + return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get SMB connection") } defer s.returnConnection(conn) @@ -427,7 +433,7 @@ func (s *SMBStorage) Stat(ctx context.Context, key string) (*storage.StorageInfo info, err := conn.share.Stat(path) if err != nil { if os.IsNotExist(err) { - return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key)) + return nil, errors.NotFound(fmt.Sprintf("SMB file not found: %s", key)) } return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to stat SMB file") } @@ -439,35 +445,35 @@ func (s *SMBStorage) Stat(ctx context.Context, key string) (*storage.StorageInfo }, nil } -// GetQuota returns quota information +// GetQuota returns current usage and quota information func (s *SMBStorage) GetQuota(ctx context.Context) (*storage.QuotaInfo, error) { - s.mu.RLock() - used := s.used - s.mu.RUnlock() + conn, err := s.getConnection() + if err != nil { + return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get SMB connection") + } + defer s.returnConnection(conn) - available := s.quota - used - if available < 0 { - available = 0 + usage, err := s.calculateUsage(conn) + if err != nil { + return nil, err } return &storage.QuotaInfo{ - Used: used, - Available: available, - Limit: s.quota, + Used: usage, + Limit: s.maxSizeBytes, }, nil } -// Health checks SMB health +// Health checks if the SMB backend is healthy func (s *SMBStorage) Health(ctx context.Context) error { - conn, err := s.getConnection(ctx) + conn, err := s.getConnection() if err != nil { - return errors.Wrap(err, errors.ErrCodeStorageFailure, "SMB health check failed - connection error") + return errors.Wrap(err, errors.ErrCodeStorageFailure, "SMB health check failed: cannot get connection") } defer s.returnConnection(conn) // Try to stat the base path - path := s.keyToPath("") - _, err = conn.share.Stat(path) + _, err = conn.share.Stat(s.config.Path) if err != nil && !os.IsNotExist(err) { return errors.Wrap(err, errors.ErrCodeStorageFailure, "SMB health check failed") } @@ -475,105 +481,62 @@ func (s *SMBStorage) Health(ctx context.Context) error { return nil } -// Close closes the storage backend +// Close closes the SMB storage backend func (s *SMBStorage) Close() error { close(s.connPool) + + // Close all connections in pool for conn := range s.connPool { conn.close() } + + log.Info().Msg("SMB storage closed") return nil } // keyToPath converts a storage key to SMB path func (s *SMBStorage) keyToPath(key string) string { - key = strings.TrimPrefix(key, "/") - key = filepath.Clean(key) + // Normalize separators to backslash for SMB + key = strings.ReplaceAll(key, "/", "\\") - // Remove path traversal attempts - for strings.HasPrefix(key, "../") || strings.HasPrefix(key, "..\\") { - key = strings.TrimPrefix(key, "../") - key = strings.TrimPrefix(key, "..\\") + if s.config.Path == "" { + return key } - key = filepath.Clean(key) - if key == ".." || strings.HasPrefix(key, "../") || strings.HasPrefix(key, "..\\") { - key = "" - } - - if s.basePath != "" { - return filepath.Join(s.basePath, key) - } - return key + // Use backslash for SMB paths + return s.config.Path + "\\" + key } -// pathToKey converts an SMB path back to a storage key +// pathToKey converts an SMB path to storage key func (s *SMBStorage) pathToKey(path string) string { - if s.basePath != "" { - path = strings.TrimPrefix(path, s.basePath) - path = strings.TrimPrefix(path, "/") - path = strings.TrimPrefix(path, "\\") + // Remove base path + if s.config.Path != "" { + path = strings.TrimPrefix(path, s.config.Path+"\\") } - return filepath.ToSlash(path) + + // Convert backslashes to forward slashes for consistency + return strings.ReplaceAll(path, "\\", "/") } -// walkPath recursively walks an SMB directory -func (s *SMBStorage) walkPath(share *smb2.Share, path string, fn func(string, os.FileInfo) error) error { - info, err := share.Stat(path) - if err != nil { - return err +// calculateUsage calculates total storage usage +func (s *SMBStorage) calculateUsage(conn *smbConnection) (int64, error) { + var totalSize int64 + + basePath := s.config.Path + if basePath == "" { + basePath = "\\" } - if !info.IsDir() { - return fn(path, info) - } - - entries, err := share.ReadDir(path) - if err != nil { - return err - } - - for _, entry := range entries { - entryPath := filepath.Join(path, entry.Name()) - if entry.IsDir() { - if err := s.walkPath(share, entryPath, fn); err != nil { - return err - } - } else { - if err := fn(entryPath, entry); err != nil { - return err - } - } - } - - return nil -} - -// calculateUsage calculates current SMB storage usage -func (s *SMBStorage) calculateUsage(ctx context.Context) error { - conn, err := s.getConnection(ctx) - if err != nil { - return err - } - defer s.returnConnection(conn) - - var total int64 - basePath := s.keyToPath("") - - err = s.walkPath(conn.share, basePath, func(path string, info os.FileInfo) error { + err := s.walkPath(conn, basePath, func(path string, info os.FileInfo) error { if !info.IsDir() { - total += info.Size() + totalSize += info.Size() } return nil }) - if err != nil && !os.IsNotExist(err) { - return err + if err != nil { + return 0, fmt.Errorf("failed to calculate usage: %w", err) } - s.mu.Lock() - s.used = total - s.mu.Unlock() - - metrics.UpdateCacheSize("smb", total) - return nil + return totalSize, nil } diff --git a/pkg/storage/smb/smb_test.go b/pkg/storage/smb/smb_test.go new file mode 100644 index 0000000..2c24f35 --- /dev/null +++ b/pkg/storage/smb/smb_test.go @@ -0,0 +1,327 @@ +package smb + +import ( + "testing" + + "github.com/stretchr/testify/suite" +) + +type SMBStorageTestSuite struct { + suite.Suite +} + +func TestSMBStorageTestSuite(t *testing.T) { + suite.Run(t, new(SMBStorageTestSuite)) +} + +func (s *SMBStorageTestSuite) TestNewSMBStorage() { + tests := []struct { + name string + errorMsg string + config Config + expectError bool + }{ + { + name: "valid config", + config: Config{ + Host: "fileserver.example.com", + Port: 445, + Share: "gohoarder", + Path: "packages", + Username: "testuser", + Password: "testpass", + Domain: "CORP", + MaxSizeBytes: 1024 * 1024, + PoolSize: 5, + }, + expectError: false, + }, + { + name: "missing host", + config: Config{ + Share: "gohoarder", + Username: "testuser", + Password: "testpass", + }, + expectError: true, + errorMsg: "host is required", + }, + { + name: "missing share", + config: Config{ + Host: "fileserver.example.com", + Username: "testuser", + Password: "testpass", + }, + expectError: true, + errorMsg: "share is required", + }, + { + name: "default port", + config: Config{ + Host: "fileserver.example.com", + Share: "gohoarder", + Username: "testuser", + Password: "testpass", + }, + expectError: false, + }, + { + name: "default pool size", + config: Config{ + Host: "fileserver.example.com", + Share: "gohoarder", + Username: "testuser", + Password: "testpass", + }, + expectError: false, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + storage, err := New(tt.config) + + if tt.expectError { + s.Error(err) + if tt.errorMsg != "" { + s.Contains(err.Error(), tt.errorMsg) + } + s.Nil(storage) + } else { + // Note: This will fail in actual execution since we can't connect to a real SMB server + // But it tests the validation logic + if err != nil { + // Connection errors are expected in unit tests + s.Contains(err.Error(), "Failed to create initial SMB connection") + } + } + }) + } +} + +func (s *SMBStorageTestSuite) TestKeyToPath() { + tests := []struct { + name string + basePath string + key string + expectedWin string // Expected Windows-style path + }{ + { + name: "simple key with base path", + basePath: "packages", + key: "test/file.txt", + expectedWin: "packages\\test\\file.txt", + }, + { + name: "simple key without base path", + basePath: "", + key: "test/file.txt", + expectedWin: "test\\file.txt", + }, + { + name: "nested key", + basePath: "cache", + key: "deep/nested/path/file.txt", + expectedWin: "cache\\deep\\nested\\path\\file.txt", + }, + { + name: "key with backslashes", + basePath: "packages", + key: "test\\file.txt", + expectedWin: "packages\\test\\file.txt", + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + storage := &SMBStorage{ + config: Config{ + Path: tt.basePath, + }, + } + + result := storage.keyToPath(tt.key) + s.Equal(tt.expectedWin, result) + }) + } +} + +func (s *SMBStorageTestSuite) TestPathToKey() { + tests := []struct { + name string + basePath string + path string + expected string + }{ + { + name: "windows path with base path", + basePath: "packages", + path: "packages\\test\\file.txt", + expected: "test/file.txt", + }, + { + name: "windows path without base path", + basePath: "", + path: "test\\file.txt", + expected: "test/file.txt", + }, + { + name: "nested windows path", + basePath: "cache", + path: "cache\\deep\\nested\\path\\file.txt", + expected: "deep/nested/path/file.txt", + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + storage := &SMBStorage{ + config: Config{ + Path: tt.basePath, + }, + } + + result := storage.pathToKey(tt.path) + s.Equal(tt.expected, result) + }) + } +} + +func (s *SMBStorageTestSuite) TestConfigDefaults() { + config := Config{ + Host: "fileserver.example.com", + Share: "gohoarder", + Username: "testuser", + Password: "testpass", + } + + // This will fail to connect, but we can verify the config validation + _, err := New(config) + + // We expect a connection error, not a validation error + if err != nil { + s.NotContains(err.Error(), "host is required") + s.NotContains(err.Error(), "share is required") + } +} + +func (s *SMBStorageTestSuite) TestPathNormalization() { + tests := []struct { + name string + inputPath string + expectedPath string + }{ + { + name: "path with trailing slash", + inputPath: "packages/", + expectedPath: "packages", + }, + { + name: "path with trailing backslash", + inputPath: "packages\\", + expectedPath: "packages", + }, + { + name: "path without trailing slash", + inputPath: "packages", + expectedPath: "packages", + }, + { + name: "empty path", + inputPath: "", + expectedPath: "", + }, + { + name: "nested path with trailing slash", + inputPath: "cache/packages/", + expectedPath: "cache/packages", + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + config := Config{ + Host: "fileserver.example.com", + Share: "gohoarder", + Path: tt.inputPath, + Username: "testuser", + Password: "testpass", + } + + // This will fail to connect, but we can check the config + storage, _ := New(config) + if storage != nil { + s.Equal(tt.expectedPath, storage.config.Path) + } + }) + } +} + +func (s *SMBStorageTestSuite) TestPoolSizeDefaults() { + config := Config{ + Host: "fileserver.example.com", + Share: "gohoarder", + Username: "testuser", + Password: "testpass", + } + + storage, _ := New(config) + if storage != nil { + s.Equal(5, storage.poolSize) // Default pool size + } +} + +func (s *SMBStorageTestSuite) TestPortDefaults() { + config := Config{ + Host: "fileserver.example.com", + Share: "gohoarder", + Username: "testuser", + Password: "testpass", + } + + storage, _ := New(config) + if storage != nil { + s.Equal(445, storage.config.Port) // Default SMB port + } +} + +func (s *SMBStorageTestSuite) TestClose() { + // Create a storage instance (will fail to connect but that's ok) + config := Config{ + Host: "fileserver.example.com", + Share: "gohoarder", + Username: "testuser", + Password: "testpass", + } + + storage, _ := New(config) + if storage != nil { + // Close should not panic + err := storage.Close() + s.NoError(err) + } +} + +func (s *SMBStorageTestSuite) TestConnectionPoolChannel() { + config := Config{ + Host: "fileserver.example.com", + Share: "gohoarder", + Username: "testuser", + Password: "testpass", + PoolSize: 10, + } + + storage, _ := New(config) + if storage != nil { + // Verify pool channel capacity + s.NotNil(storage.connPool) + s.Equal(10, cap(storage.connPool)) + } +} + +func (s *SMBStorageTestSuite) TestSMBConnectionStruct() { + // Verify smbConnection structure exists and has required fields + conn := &smbConnection{} + s.NotNil(conn) +} diff --git a/pkg/uuid/uuid_test.go b/pkg/uuid/uuid_test.go index 0aa77f6..71cdadf 100644 --- a/pkg/uuid/uuid_test.go +++ b/pkg/uuid/uuid_test.go @@ -56,8 +56,8 @@ func TestNew(t *testing.T) { func TestString(t *testing.T) { tests := []struct { name string - uuid UUID expected string + uuid UUID }{ { name: "zero UUID", diff --git a/pkg/vcs/git.go b/pkg/vcs/git.go index eab6518..03af71c 100644 --- a/pkg/vcs/git.go +++ b/pkg/vcs/git.go @@ -14,9 +14,9 @@ import ( // GitFetcher handles git repository operations type GitFetcher struct { + credStore *CredentialStore workDir string timeout time.Duration - credStore *CredentialStore } // NewGitFetcher creates a new git fetcher diff --git a/pkg/vcs/module.go b/pkg/vcs/module.go index 5a1b0bb..647db26 100644 --- a/pkg/vcs/module.go +++ b/pkg/vcs/module.go @@ -27,8 +27,8 @@ func NewModuleBuilder() *ModuleBuilder { // ModuleInfo represents Go module version metadata (.info file) type ModuleInfo struct { - Version string `json:"Version"` Time time.Time `json:"Time"` + Version string `json:"Version"` } // BuildModuleZip creates a Go module zip from source directory diff --git a/pkg/websocket/server.go b/pkg/websocket/server.go index 44c7333..047f33a 100644 --- a/pkg/websocket/server.go +++ b/pkg/websocket/server.go @@ -25,9 +25,9 @@ const ( // Event represents a WebSocket event message type Event struct { - Type EventType `json:"type"` Timestamp time.Time `json:"timestamp"` Data map[string]interface{} `json:"data"` + Type EventType `json:"type"` } // Client represents a WebSocket client connection @@ -45,15 +45,15 @@ type Server struct { broadcast chan Event register chan *Client unregister chan *Client - mu sync.RWMutex upgrader websocket.Upgrader + mu sync.RWMutex } // Config holds WebSocket server configuration type Config struct { + CheckOrigin func(r *http.Request) bool ReadBufferSize int WriteBufferSize int - CheckOrigin func(r *http.Request) bool } // NewServer creates a new WebSocket server @@ -307,8 +307,8 @@ func (c *Client) writePump() { // handleMessage processes incoming client messages func (c *Client) handleMessage(message []byte) { var msg struct { - Action string `json:"action"` Data interface{} `json:"data"` + Action string `json:"action"` } if err := json.Unmarshal(message, &msg); err != nil { @@ -379,10 +379,3 @@ func (c *Client) sendPong() { default: } } - -// GetConnectedClients returns the number of connected clients -func (s *Server) GetConnectedClients() int { - s.mu.RLock() - defer s.mu.RUnlock() - return len(s.clients) -}