mirror of
https://github.com/lukaszraczylo/gohoarder.git
synced 2026-06-16 01:01:20 +00:00
chore(schema): migrate to GORM V2 with multi-database support
- [x] Implement GORM V2 metadata store with SQLite, PostgreSQL, and MySQL support - [x] Add database migration system using gormigrate for schema versioning - [x] Create migration CLI tool with support for migrate, rollback, and status commands - [x] Add Docker support for migration container (Dockerfile.migrate) - [x] Implement automatic partition management for PostgreSQL time-series tables - [x] Add background aggregation worker for download statistics - [x] Support connection pooling configuration (max_open_conns, max_idle_conns, conn_max_lifetime) - [x] Add blocking mechanism based on vulnerability thresholds in stats and handlers - [x] Update Helm charts with migration init containers and multi-database configuration - [x] Replace deprecated SQLite store with optimized GORM implementation - [x] Add comprehensive integration tests for MySQL and PostgreSQL - [x] Update frontend to display blocked packages and storage utilization - [x] Add goreleaser configuration for migrate binary and container image - [x] Update configuration examples with database backend options and recommendations
This commit is contained in:
+72
-7
@@ -19,7 +19,7 @@ import (
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/health"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
|
||||
metafile "github.com/lukaszraczylo/gohoarder/pkg/metadata/file"
|
||||
metasqlite "github.com/lukaszraczylo/gohoarder/pkg/metadata/sqlite"
|
||||
metagorm "github.com/lukaszraczylo/gohoarder/pkg/metadata/gormstore"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metrics"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/network"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/prewarming"
|
||||
@@ -119,18 +119,67 @@ func (a *App) initializeComponents() error {
|
||||
log.Info().Str("backend", a.config.Metadata.Backend).Msg("Initializing metadata store")
|
||||
switch a.config.Metadata.Backend {
|
||||
case "sqlite":
|
||||
a.metadata, err = metasqlite.New(metasqlite.Config{
|
||||
Path: a.config.Metadata.Connection,
|
||||
WALMode: a.config.Metadata.SQLite.WALMode,
|
||||
// Use GORM for SQLite
|
||||
a.metadata, err = metagorm.NewV2(metagorm.Config{
|
||||
Driver: "sqlite",
|
||||
DSN: metagorm.BuildSQLiteDSN(a.config.Metadata.SQLite.Path, a.config.Metadata.SQLite.WALMode),
|
||||
MaxOpenConns: getOrDefault(a.config.Metadata.MaxOpenConns, 25),
|
||||
MaxIdleConns: getOrDefault(a.config.Metadata.MaxIdleConns, 5),
|
||||
ConnMaxLifetime: time.Duration(getOrDefault(a.config.Metadata.ConnMaxLifetime, 3600)) * time.Second,
|
||||
LogLevel: getOrDefaultStr(a.config.Metadata.LogLevel, "warn"),
|
||||
})
|
||||
|
||||
case "postgresql", "postgres":
|
||||
// Use GORM for PostgreSQL
|
||||
dsn := metagorm.BuildPostgresDSN(
|
||||
a.config.Metadata.PostgreSQL.Host,
|
||||
a.config.Metadata.PostgreSQL.Port,
|
||||
a.config.Metadata.PostgreSQL.User,
|
||||
a.config.Metadata.PostgreSQL.Password,
|
||||
a.config.Metadata.PostgreSQL.Database,
|
||||
getOrDefaultStr(a.config.Metadata.PostgreSQL.SSLMode, "disable"),
|
||||
)
|
||||
a.metadata, err = metagorm.NewV2(metagorm.Config{
|
||||
Driver: "postgres",
|
||||
DSN: dsn,
|
||||
MaxOpenConns: getOrDefault(a.config.Metadata.MaxOpenConns, 25),
|
||||
MaxIdleConns: getOrDefault(a.config.Metadata.MaxIdleConns, 5),
|
||||
ConnMaxLifetime: time.Duration(getOrDefault(a.config.Metadata.ConnMaxLifetime, 3600)) * time.Second,
|
||||
LogLevel: getOrDefaultStr(a.config.Metadata.LogLevel, "warn"),
|
||||
})
|
||||
|
||||
case "mysql", "mariadb":
|
||||
// Use GORM for MySQL/MariaDB
|
||||
dsn := metagorm.BuildMySQLDSN(
|
||||
a.config.Metadata.MySQL.Host,
|
||||
a.config.Metadata.MySQL.Port,
|
||||
a.config.Metadata.MySQL.User,
|
||||
a.config.Metadata.MySQL.Password,
|
||||
a.config.Metadata.MySQL.Database,
|
||||
getOrDefaultStr(a.config.Metadata.MySQL.Charset, "utf8mb4"),
|
||||
)
|
||||
a.metadata, err = metagorm.NewV2(metagorm.Config{
|
||||
Driver: "mysql",
|
||||
DSN: dsn,
|
||||
MaxOpenConns: getOrDefault(a.config.Metadata.MaxOpenConns, 25),
|
||||
MaxIdleConns: getOrDefault(a.config.Metadata.MaxIdleConns, 5),
|
||||
ConnMaxLifetime: time.Duration(getOrDefault(a.config.Metadata.ConnMaxLifetime, 3600)) * time.Second,
|
||||
LogLevel: getOrDefaultStr(a.config.Metadata.LogLevel, "warn"),
|
||||
})
|
||||
|
||||
case "file":
|
||||
// Keep file backend as-is for file-based metadata
|
||||
a.metadata, err = metafile.New(metafile.Config{
|
||||
Path: a.config.Metadata.Connection,
|
||||
})
|
||||
|
||||
default:
|
||||
a.metadata, err = metasqlite.New(metasqlite.Config{
|
||||
Path: "gohoarder.db",
|
||||
WALMode: false, // Default to DELETE mode for compatibility
|
||||
// Default to SQLite with GORM
|
||||
log.Warn().Str("backend", a.config.Metadata.Backend).Msg("Unknown metadata backend, defaulting to SQLite with GORM")
|
||||
a.metadata, err = metagorm.NewV2(metagorm.Config{
|
||||
Driver: "sqlite",
|
||||
DSN: metagorm.BuildSQLiteDSN("gohoarder.db", false),
|
||||
LogLevel: "warn",
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
@@ -479,3 +528,19 @@ func (a *App) startAggregationWorker(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getOrDefault returns the value if it's non-zero, otherwise returns the default
|
||||
func getOrDefault(value, defaultValue int) int {
|
||||
if value == 0 {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// getOrDefaultStr returns the value if it's non-empty, otherwise returns the default
|
||||
func getOrDefaultStr(value, defaultValue string) string {
|
||||
if value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
+27
-6
@@ -142,6 +142,13 @@ func (a *App) handleListPackages(c *fiber.Ctx) error {
|
||||
severityCounts[strings.ToUpper(vuln.Severity)]++
|
||||
}
|
||||
|
||||
// Check if package should be blocked based on thresholds
|
||||
isBlocked := false
|
||||
if a.scanManager != nil {
|
||||
blocked, _, _ := a.scanManager.CheckVulnerabilities(ctx, pkg.Registry, entry.originalName, pkg.Version)
|
||||
isBlocked = blocked
|
||||
}
|
||||
|
||||
pkgMap["vulnerabilities"] = map[string]interface{}{
|
||||
"scanned": true,
|
||||
"status": scanResult.Status,
|
||||
@@ -152,18 +159,21 @@ func (a *App) handleListPackages(c *fiber.Ctx) error {
|
||||
"moderate": severityCounts["MODERATE"],
|
||||
"low": severityCounts["LOW"],
|
||||
},
|
||||
"total": scanResult.VulnerabilityCount,
|
||||
"total": scanResult.VulnerabilityCount,
|
||||
"isBlocked": isBlocked,
|
||||
}
|
||||
} else {
|
||||
pkgMap["vulnerabilities"] = map[string]interface{}{
|
||||
"scanned": false,
|
||||
"status": "pending",
|
||||
"scanned": false,
|
||||
"status": "pending",
|
||||
"isBlocked": false,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
pkgMap["vulnerabilities"] = map[string]interface{}{
|
||||
"scanned": false,
|
||||
"status": "not_scanned",
|
||||
"scanned": false,
|
||||
"status": "not_scanned",
|
||||
"isBlocked": false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -351,8 +361,9 @@ func (a *App) handleStats(c *fiber.Ctx) error {
|
||||
packages = []*metadata.Package{}
|
||||
}
|
||||
|
||||
// Calculate per-registry breakdown (exclude metadata entries like "list", "latest")
|
||||
// Calculate per-registry breakdown and blocked packages count
|
||||
registryStats := make(map[string]map[string]interface{})
|
||||
blockedCount := int64(0)
|
||||
|
||||
for _, pkg := range packages {
|
||||
// Skip metadata entries (npm metadata pages, pypi pages, etc.)
|
||||
@@ -371,6 +382,14 @@ func (a *App) handleStats(c *fiber.Ctx) error {
|
||||
registryStats[pkg.Registry]["count"] = registryStats[pkg.Registry]["count"].(int) + 1
|
||||
registryStats[pkg.Registry]["size"] = registryStats[pkg.Registry]["size"].(int64) + pkg.Size
|
||||
registryStats[pkg.Registry]["downloads"] = registryStats[pkg.Registry]["downloads"].(int64) + int64(pkg.DownloadCount)
|
||||
|
||||
// Check if package is blocked (only if security scanning is enabled and package is scanned)
|
||||
if a.config.Security.Enabled && a.scanManager != nil && pkg.SecurityScanned {
|
||||
blocked, _, _ := a.scanManager.CheckVulnerabilities(ctx, pkg.Registry, pkg.Name, pkg.Version)
|
||||
if blocked {
|
||||
blockedCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Combine statistics using database stats for accuracy
|
||||
@@ -378,12 +397,14 @@ func (a *App) handleStats(c *fiber.Ctx) error {
|
||||
"total_packages": cacheStats.TotalPackages,
|
||||
"total_downloads": cacheStats.TotalDownloads,
|
||||
"total_size": cacheStats.TotalSize,
|
||||
"max_cache_size": a.config.Cache.MaxSizeBytes,
|
||||
"cache_hits": cacheStats.TotalDownloads,
|
||||
"cache_misses": 0, // TODO: Track cache misses
|
||||
"cache_evictions": 0, // TODO: Track evictions
|
||||
"cache_size": cacheStats.TotalSize,
|
||||
"scanned_packages": cacheStats.ScannedPackages,
|
||||
"vulnerable_packages": cacheStats.VulnerablePackages,
|
||||
"blocked_packages": blockedCount,
|
||||
}
|
||||
|
||||
// Convert registry stats to interface map
|
||||
|
||||
Vendored
+14
-7
@@ -203,9 +203,12 @@ func (m *Manager) getOrFetch(ctx context.Context, registry, name, version string
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Skip security scan wait for metadata entries (index pages, lists, etc.)
|
||||
isMetadataEntry := version == "list" || version == "page" || version == "latest" || version == "metadata"
|
||||
|
||||
// Wait briefly for initial scan to complete if scanner is enabled
|
||||
// This prevents serving vulnerable packages on first request
|
||||
if m.scanner != nil {
|
||||
if m.scanner != nil && !isMetadataEntry {
|
||||
// Wait up to 30 seconds for scan to complete
|
||||
scanCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
@@ -360,15 +363,19 @@ func (m *Manager) store(ctx context.Context, registry, name, version string, dat
|
||||
Metadata: make(map[string]string),
|
||||
}
|
||||
|
||||
// Save metadata
|
||||
if err := m.metadata.SavePackage(ctx, pkg); err != nil {
|
||||
// Clean up storage if metadata save fails
|
||||
_ = m.storage.Delete(ctx, storageKey) // #nosec G104 -- Cleanup, error logged
|
||||
return nil, err
|
||||
// Save metadata (skip metadata entries like index pages, lists, etc.)
|
||||
isMetadataEntry := version == "list" || version == "page" || version == "latest" || version == "metadata"
|
||||
if !isMetadataEntry {
|
||||
if err := m.metadata.SavePackage(ctx, pkg); err != nil {
|
||||
// Clean up storage if metadata save fails
|
||||
_ = m.storage.Delete(ctx, storageKey) // #nosec G104 -- Cleanup, error logged
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Scan package if scanner is enabled (run in background to not block cache operations)
|
||||
if m.scanner != nil {
|
||||
// Skip scanning metadata entries (index pages, lists, etc.)
|
||||
if m.scanner != nil && !isMetadataEntry {
|
||||
go func() {
|
||||
scanCtx := context.Background()
|
||||
var filePath string
|
||||
|
||||
+29
-4
@@ -73,10 +73,17 @@ type SMBConfig struct {
|
||||
|
||||
// MetadataConfig contains metadata store configuration
|
||||
type MetadataConfig struct {
|
||||
PostgreSQL PostgreSQLConfig `mapstructure:"postgresql" json:"postgresql"`
|
||||
Backend string `mapstructure:"backend" json:"backend"`
|
||||
Connection string `mapstructure:"connection" json:"connection"`
|
||||
SQLite SQLiteConfig `mapstructure:"sqlite" json:"sqlite"`
|
||||
PostgreSQL PostgreSQLConfig `mapstructure:"postgresql" json:"postgresql"`
|
||||
MySQL MySQLConfig `mapstructure:"mysql" json:"mysql"`
|
||||
|
||||
// GORM-specific settings
|
||||
MaxOpenConns int `mapstructure:"max_open_conns" json:"max_open_conns"`
|
||||
MaxIdleConns int `mapstructure:"max_idle_conns" json:"max_idle_conns"`
|
||||
ConnMaxLifetime int `mapstructure:"conn_max_lifetime" json:"conn_max_lifetime"` // seconds
|
||||
LogLevel string `mapstructure:"log_level" json:"log_level"` // "silent", "error", "warn", "info"
|
||||
}
|
||||
|
||||
// SQLiteConfig contains SQLite-specific configuration
|
||||
@@ -88,11 +95,22 @@ type SQLiteConfig struct {
|
||||
// PostgreSQLConfig contains PostgreSQL-specific configuration
|
||||
type PostgreSQLConfig struct {
|
||||
Host string `mapstructure:"host" json:"host"`
|
||||
Port int `mapstructure:"port" json:"port"`
|
||||
Database string `mapstructure:"database" json:"database"`
|
||||
User string `mapstructure:"user" json:"user"`
|
||||
Password string `mapstructure:"password" json:"-"`
|
||||
SSLMode string `mapstructure:"ssl_mode" json:"ssl_mode"`
|
||||
Port int `mapstructure:"port" json:"port"`
|
||||
}
|
||||
|
||||
// MySQLConfig contains MySQL/MariaDB-specific configuration
|
||||
type MySQLConfig struct {
|
||||
Host string `mapstructure:"host" json:"host"`
|
||||
Port int `mapstructure:"port" json:"port"`
|
||||
Database string `mapstructure:"database" json:"database"`
|
||||
User string `mapstructure:"user" json:"user"`
|
||||
Password string `mapstructure:"password" json:"-"` // Don't serialize
|
||||
Charset string `mapstructure:"charset" json:"charset"`
|
||||
ParseTime bool `mapstructure:"parse_time" json:"parse_time"`
|
||||
}
|
||||
|
||||
// CacheConfig contains cache management configuration
|
||||
@@ -415,9 +433,16 @@ func (c *Config) Validate() error {
|
||||
}
|
||||
|
||||
// Validate metadata backend
|
||||
validMetadataBackends := map[string]bool{"sqlite": true, "postgresql": true, "file": true}
|
||||
validMetadataBackends := map[string]bool{
|
||||
"sqlite": true,
|
||||
"postgresql": true,
|
||||
"postgres": true,
|
||||
"mysql": true,
|
||||
"mariadb": true,
|
||||
"file": true,
|
||||
}
|
||||
if !validMetadataBackends[c.Metadata.Backend] {
|
||||
return fmt.Errorf("metadata.backend must be one of: sqlite, postgresql, file; got %s", c.Metadata.Backend)
|
||||
return fmt.Errorf("metadata.backend must be one of: sqlite, postgresql, mysql, file; got %s", c.Metadata.Backend)
|
||||
}
|
||||
|
||||
// Validate cache
|
||||
|
||||
@@ -0,0 +1,357 @@
|
||||
package gormstore
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AggregationWorker handles background aggregation of download statistics
|
||||
type AggregationWorker struct {
|
||||
db *gorm.DB
|
||||
stopChan chan struct{}
|
||||
ticker *time.Ticker
|
||||
}
|
||||
|
||||
// NewAggregationWorker creates a new aggregation worker
|
||||
func NewAggregationWorker(db *gorm.DB) *AggregationWorker {
|
||||
return &AggregationWorker{
|
||||
db: db,
|
||||
stopChan: make(chan struct{}),
|
||||
ticker: time.NewTicker(1 * time.Hour), // Run every hour
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the aggregation worker
|
||||
func (w *AggregationWorker) Start() {
|
||||
log.Info().Msg("Starting aggregation worker")
|
||||
|
||||
// Run immediately on start
|
||||
if err := w.AggregateHourly(); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to run initial hourly aggregation")
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-w.ticker.C:
|
||||
if err := w.AggregateHourly(); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to aggregate hourly stats")
|
||||
}
|
||||
|
||||
// Check if it's time for daily aggregation (run at midnight)
|
||||
now := time.Now()
|
||||
if now.Hour() == 0 {
|
||||
if err := w.AggregateDaily(); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to aggregate daily stats")
|
||||
}
|
||||
}
|
||||
|
||||
case <-w.stopChan:
|
||||
log.Info().Msg("Stopping aggregation worker")
|
||||
w.ticker.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the aggregation worker
|
||||
func (w *AggregationWorker) Stop() {
|
||||
close(w.stopChan)
|
||||
}
|
||||
|
||||
// AggregateHourly aggregates download events into hourly stats
|
||||
func (w *AggregationWorker) AggregateHourly() error {
|
||||
startTime := time.Now()
|
||||
log.Debug().Msg("Starting hourly aggregation")
|
||||
|
||||
// Get dialect name
|
||||
dialectName := w.db.Dialector.Name()
|
||||
|
||||
// Calculate cutoff time (aggregate events older than 5 minutes to avoid partial data)
|
||||
cutoff := time.Now().Add(-5 * time.Minute).Truncate(time.Hour)
|
||||
|
||||
return w.db.Transaction(func(tx *gorm.DB) error {
|
||||
var aggregateSQL string
|
||||
|
||||
switch dialectName {
|
||||
case "postgres":
|
||||
// PostgreSQL: Use date_trunc for time bucketing
|
||||
aggregateSQL = `
|
||||
INSERT INTO download_stats_hourly (registry_id, package_id, time_bucket, download_count, unique_ips, auth_downloads, created_at, updated_at)
|
||||
SELECT
|
||||
de.registry_id,
|
||||
de.package_id,
|
||||
date_trunc('hour', de.downloaded_at) AS time_bucket,
|
||||
COUNT(*) AS download_count,
|
||||
COUNT(DISTINCT de.ip_address) AS unique_ips,
|
||||
COUNT(*) FILTER (WHERE de.authenticated = true) AS auth_downloads,
|
||||
NOW() AS created_at,
|
||||
NOW() AS updated_at
|
||||
FROM download_events de
|
||||
WHERE de.downloaded_at < ?
|
||||
GROUP BY de.registry_id, de.package_id, time_bucket
|
||||
ON CONFLICT (registry_id, COALESCE(package_id, 0), time_bucket)
|
||||
DO UPDATE SET
|
||||
download_count = download_stats_hourly.download_count + EXCLUDED.download_count,
|
||||
unique_ips = GREATEST(download_stats_hourly.unique_ips, EXCLUDED.unique_ips),
|
||||
auth_downloads = download_stats_hourly.auth_downloads + EXCLUDED.auth_downloads,
|
||||
updated_at = NOW()
|
||||
`
|
||||
|
||||
case "mysql":
|
||||
// MySQL: Use DATE_FORMAT for time bucketing
|
||||
aggregateSQL = `
|
||||
INSERT INTO download_stats_hourly (registry_id, package_id, time_bucket, download_count, unique_ips, auth_downloads, created_at, updated_at)
|
||||
SELECT
|
||||
de.registry_id,
|
||||
de.package_id,
|
||||
DATE_FORMAT(de.downloaded_at, '%Y-%m-%d %H:00:00') AS time_bucket,
|
||||
COUNT(*) AS download_count,
|
||||
COUNT(DISTINCT de.ip_address) AS unique_ips,
|
||||
SUM(CASE WHEN de.authenticated = true THEN 1 ELSE 0 END) AS auth_downloads,
|
||||
NOW() AS created_at,
|
||||
NOW() AS updated_at
|
||||
FROM download_events de
|
||||
WHERE de.downloaded_at < ?
|
||||
GROUP BY de.registry_id, de.package_id, time_bucket
|
||||
ON DUPLICATE KEY UPDATE
|
||||
download_count = download_stats_hourly.download_count + VALUES(download_count),
|
||||
unique_ips = GREATEST(download_stats_hourly.unique_ips, VALUES(unique_ips)),
|
||||
auth_downloads = download_stats_hourly.auth_downloads + VALUES(auth_downloads),
|
||||
updated_at = NOW()
|
||||
`
|
||||
|
||||
default: // SQLite
|
||||
// SQLite: Use strftime for time bucketing
|
||||
// Note: SQLite doesn't support UPSERT as elegantly, need to handle separately
|
||||
aggregateSQL = `
|
||||
INSERT OR REPLACE INTO download_stats_hourly (registry_id, package_id, time_bucket, download_count, unique_ips, auth_downloads, created_at, updated_at)
|
||||
SELECT
|
||||
de.registry_id,
|
||||
de.package_id,
|
||||
strftime('%Y-%m-%d %H:00:00', de.downloaded_at) AS time_bucket,
|
||||
COUNT(*) AS download_count,
|
||||
COUNT(DISTINCT de.ip_address) AS unique_ips,
|
||||
SUM(CASE WHEN de.authenticated = 1 THEN 1 ELSE 0 END) AS auth_downloads,
|
||||
datetime('now') AS created_at,
|
||||
datetime('now') AS updated_at
|
||||
FROM download_events de
|
||||
WHERE de.downloaded_at < ?
|
||||
GROUP BY de.registry_id, de.package_id, time_bucket
|
||||
`
|
||||
}
|
||||
|
||||
// Execute aggregation
|
||||
if err := tx.Exec(aggregateSQL, cutoff).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete aggregated events (older than 24 hours to keep recent data for debugging)
|
||||
deleteOlder := time.Now().Add(-24 * time.Hour)
|
||||
deleteResult := tx.Exec("DELETE FROM download_events WHERE downloaded_at < ?", deleteOlder)
|
||||
if deleteResult.Error != nil {
|
||||
return deleteResult.Error
|
||||
}
|
||||
|
||||
// Also update package-level stats (NULL package_id = registry totals)
|
||||
var registryAggSQL string
|
||||
if dialectName == "postgres" {
|
||||
registryAggSQL = `
|
||||
INSERT INTO download_stats_hourly (registry_id, package_id, time_bucket, download_count, unique_ips, auth_downloads, created_at, updated_at)
|
||||
SELECT
|
||||
registry_id,
|
||||
NULL as package_id,
|
||||
time_bucket,
|
||||
SUM(download_count) as download_count,
|
||||
SUM(unique_ips) as unique_ips,
|
||||
SUM(auth_downloads) as auth_downloads,
|
||||
NOW() as created_at,
|
||||
NOW() as updated_at
|
||||
FROM download_stats_hourly
|
||||
WHERE package_id IS NOT NULL
|
||||
GROUP BY registry_id, time_bucket
|
||||
ON CONFLICT (registry_id, COALESCE(package_id, 0), time_bucket)
|
||||
DO UPDATE SET
|
||||
download_count = EXCLUDED.download_count,
|
||||
unique_ips = EXCLUDED.unique_ips,
|
||||
auth_downloads = EXCLUDED.auth_downloads,
|
||||
updated_at = NOW()
|
||||
`
|
||||
} else if dialectName == "mysql" {
|
||||
registryAggSQL = `
|
||||
INSERT INTO download_stats_hourly (registry_id, package_id, time_bucket, download_count, unique_ips, auth_downloads, created_at, updated_at)
|
||||
SELECT
|
||||
registry_id,
|
||||
NULL as package_id,
|
||||
time_bucket,
|
||||
SUM(download_count) as download_count,
|
||||
SUM(unique_ips) as unique_ips,
|
||||
SUM(auth_downloads) as auth_downloads,
|
||||
NOW() as created_at,
|
||||
NOW() as updated_at
|
||||
FROM download_stats_hourly
|
||||
WHERE package_id IS NOT NULL
|
||||
GROUP BY registry_id, time_bucket
|
||||
ON DUPLICATE KEY UPDATE
|
||||
download_count = VALUES(download_count),
|
||||
unique_ips = VALUES(unique_ips),
|
||||
auth_downloads = VALUES(auth_downloads),
|
||||
updated_at = NOW()
|
||||
`
|
||||
} else {
|
||||
// SQLite
|
||||
registryAggSQL = `
|
||||
INSERT OR REPLACE INTO download_stats_hourly (registry_id, package_id, time_bucket, download_count, unique_ips, auth_downloads, created_at, updated_at)
|
||||
SELECT
|
||||
registry_id,
|
||||
NULL as package_id,
|
||||
time_bucket,
|
||||
SUM(download_count) as download_count,
|
||||
SUM(unique_ips) as unique_ips,
|
||||
SUM(auth_downloads) as auth_downloads,
|
||||
datetime('now') as created_at,
|
||||
datetime('now') as updated_at
|
||||
FROM download_stats_hourly
|
||||
WHERE package_id IS NOT NULL
|
||||
GROUP BY registry_id, time_bucket
|
||||
`
|
||||
}
|
||||
|
||||
if err := tx.Exec(registryAggSQL).Error; err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to aggregate registry totals (continuing anyway)")
|
||||
}
|
||||
|
||||
elapsed := time.Since(startTime)
|
||||
log.Info().
|
||||
Int64("deleted_events", deleteResult.RowsAffected).
|
||||
Dur("duration", elapsed).
|
||||
Msg("Completed hourly aggregation")
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// AggregateDaily aggregates hourly stats into daily stats
|
||||
func (w *AggregationWorker) AggregateDaily() error {
|
||||
startTime := time.Now()
|
||||
log.Debug().Msg("Starting daily aggregation")
|
||||
|
||||
dialectName := w.db.Dialector.Name()
|
||||
|
||||
// Aggregate yesterday's data
|
||||
yesterday := time.Now().AddDate(0, 0, -1).Truncate(24 * time.Hour)
|
||||
dayEnd := yesterday.Add(24 * time.Hour)
|
||||
|
||||
return w.db.Transaction(func(tx *gorm.DB) error {
|
||||
var aggregateSQL string
|
||||
|
||||
switch dialectName {
|
||||
case "postgres":
|
||||
aggregateSQL = `
|
||||
INSERT INTO download_stats_daily (registry_id, package_id, time_bucket, download_count, unique_ips, auth_downloads, top_user_agents, created_at, updated_at)
|
||||
SELECT
|
||||
registry_id,
|
||||
package_id,
|
||||
date_trunc('day', time_bucket) AS time_bucket,
|
||||
SUM(download_count) AS download_count,
|
||||
MAX(unique_ips) AS unique_ips,
|
||||
SUM(auth_downloads) AS auth_downloads,
|
||||
'{}' AS top_user_agents,
|
||||
NOW() AS created_at,
|
||||
NOW() AS updated_at
|
||||
FROM download_stats_hourly
|
||||
WHERE time_bucket >= ? AND time_bucket < ?
|
||||
GROUP BY registry_id, package_id, date_trunc('day', time_bucket)
|
||||
ON CONFLICT (registry_id, COALESCE(package_id, 0), time_bucket)
|
||||
DO UPDATE SET
|
||||
download_count = EXCLUDED.download_count,
|
||||
unique_ips = EXCLUDED.unique_ips,
|
||||
auth_downloads = EXCLUDED.auth_downloads,
|
||||
updated_at = NOW()
|
||||
`
|
||||
|
||||
case "mysql":
|
||||
aggregateSQL = `
|
||||
INSERT INTO download_stats_daily (registry_id, package_id, time_bucket, download_count, unique_ips, auth_downloads, top_user_agents, created_at, updated_at)
|
||||
SELECT
|
||||
registry_id,
|
||||
package_id,
|
||||
DATE_FORMAT(time_bucket, '%Y-%m-%d 00:00:00') AS time_bucket,
|
||||
SUM(download_count) AS download_count,
|
||||
MAX(unique_ips) AS unique_ips,
|
||||
SUM(auth_downloads) AS auth_downloads,
|
||||
'{}' AS top_user_agents,
|
||||
NOW() AS created_at,
|
||||
NOW() AS updated_at
|
||||
FROM download_stats_hourly
|
||||
WHERE time_bucket >= ? AND time_bucket < ?
|
||||
GROUP BY registry_id, package_id, DATE_FORMAT(time_bucket, '%Y-%m-%d 00:00:00')
|
||||
ON DUPLICATE KEY UPDATE
|
||||
download_count = VALUES(download_count),
|
||||
unique_ips = VALUES(unique_ips),
|
||||
auth_downloads = VALUES(auth_downloads),
|
||||
updated_at = NOW()
|
||||
`
|
||||
|
||||
default: // SQLite
|
||||
aggregateSQL = `
|
||||
INSERT OR REPLACE INTO download_stats_daily (registry_id, package_id, time_bucket, download_count, unique_ips, auth_downloads, top_user_agents, created_at, updated_at)
|
||||
SELECT
|
||||
registry_id,
|
||||
package_id,
|
||||
date(time_bucket) AS time_bucket,
|
||||
SUM(download_count) AS download_count,
|
||||
MAX(unique_ips) AS unique_ips,
|
||||
SUM(auth_downloads) AS auth_downloads,
|
||||
'{}' AS top_user_agents,
|
||||
datetime('now') AS created_at,
|
||||
datetime('now') AS updated_at
|
||||
FROM download_stats_hourly
|
||||
WHERE time_bucket >= ? AND time_bucket < ?
|
||||
GROUP BY registry_id, package_id, date(time_bucket)
|
||||
`
|
||||
}
|
||||
|
||||
if err := tx.Exec(aggregateSQL, yesterday, dayEnd).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete old hourly stats (keep last 7 days)
|
||||
deleteOlder := time.Now().AddDate(0, 0, -7)
|
||||
deleteResult := tx.Exec("DELETE FROM download_stats_hourly WHERE time_bucket < ?", deleteOlder)
|
||||
if deleteResult.Error != nil {
|
||||
return deleteResult.Error
|
||||
}
|
||||
|
||||
elapsed := time.Since(startTime)
|
||||
log.Info().
|
||||
Int64("deleted_hourly_stats", deleteResult.RowsAffected).
|
||||
Dur("duration", elapsed).
|
||||
Msg("Completed daily aggregation")
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// UpdatePackageAccessCounts synchronizes package access_count from download stats
|
||||
func (w *AggregationWorker) UpdatePackageAccessCounts() error {
|
||||
log.Debug().Msg("Updating package access counts")
|
||||
|
||||
// Update from download_stats_hourly (sum all-time downloads per package)
|
||||
updateSQL := `
|
||||
UPDATE packages p
|
||||
SET access_count = COALESCE((
|
||||
SELECT SUM(download_count)
|
||||
FROM download_stats_hourly dsh
|
||||
WHERE dsh.package_id = p.id
|
||||
), 0)
|
||||
`
|
||||
|
||||
if err := w.db.Exec(updateSQL).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info().Msg("Updated package access counts")
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
package gormstore
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/errors"
|
||||
)
|
||||
|
||||
// Config holds GORM store configuration
|
||||
type Config struct {
|
||||
// Database connection
|
||||
Driver string // "sqlite", "postgres", "mysql"
|
||||
DSN string // Data Source Name
|
||||
|
||||
// Connection pool
|
||||
MaxOpenConns int
|
||||
MaxIdleConns int
|
||||
ConnMaxLifetime time.Duration
|
||||
|
||||
// GORM settings
|
||||
LogLevel string // "silent", "error", "warn", "info"
|
||||
}
|
||||
|
||||
// Validate validates the configuration
|
||||
func (c *Config) Validate() error {
|
||||
if c.Driver == "" {
|
||||
return errors.New(errors.ErrCodeInvalidConfig, "driver is required")
|
||||
}
|
||||
if c.DSN == "" {
|
||||
return errors.New(errors.ErrCodeInvalidConfig, "DSN is required")
|
||||
}
|
||||
|
||||
// Set defaults
|
||||
if c.MaxOpenConns == 0 {
|
||||
c.MaxOpenConns = 25
|
||||
}
|
||||
if c.MaxIdleConns == 0 {
|
||||
c.MaxIdleConns = 5
|
||||
}
|
||||
if c.ConnMaxLifetime == 0 {
|
||||
c.ConnMaxLifetime = time.Hour
|
||||
}
|
||||
if c.LogLevel == "" {
|
||||
c.LogLevel = "warn"
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildPostgresDSN builds PostgreSQL DSN from structured config
|
||||
func BuildPostgresDSN(host string, port int, user, password, database, sslmode string) string {
|
||||
if sslmode == "" {
|
||||
sslmode = "disable"
|
||||
}
|
||||
return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
||||
host, port, user, password, database, sslmode)
|
||||
}
|
||||
|
||||
// BuildMySQLDSN builds MySQL/MariaDB DSN from structured config
|
||||
func BuildMySQLDSN(host string, port int, user, password, database, charset string) string {
|
||||
if charset == "" {
|
||||
charset = "utf8mb4"
|
||||
}
|
||||
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=True&loc=Local",
|
||||
user, password, host, port, database, charset)
|
||||
}
|
||||
|
||||
// BuildSQLiteDSN builds SQLite DSN with pragmas
|
||||
func BuildSQLiteDSN(path string, walMode bool) string {
|
||||
if path == "" {
|
||||
path = "gohoarder.db"
|
||||
}
|
||||
if walMode {
|
||||
return fmt.Sprintf("%s?_journal_mode=WAL&_busy_timeout=5000&_synchronous=NORMAL&_cache_size=2000", path)
|
||||
}
|
||||
return fmt.Sprintf("%s?_journal_mode=DELETE&_busy_timeout=5000&_synchronous=NORMAL", path)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,279 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package gormstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mysql"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
)
|
||||
|
||||
// MySQLV2IntegrationTestSuite embeds the V2 test suite with MySQL container
|
||||
type MySQLV2IntegrationTestSuite struct {
|
||||
GORMStoreV2TestSuite
|
||||
container *mysql.MySQLContainer
|
||||
}
|
||||
|
||||
// SetupSuite runs once before all tests
|
||||
func (s *MySQLV2IntegrationTestSuite) SetupSuite() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Start MySQL container
|
||||
container, err := mysql.RunContainer(ctx,
|
||||
testcontainers.WithImage("mysql:8.0"),
|
||||
mysql.WithDatabase("testdb"),
|
||||
mysql.WithUsername("testuser"),
|
||||
mysql.WithPassword("testpass"),
|
||||
testcontainers.WithWaitStrategy(
|
||||
wait.ForLog("port: 3306 MySQL Community Server").
|
||||
WithOccurrence(1).
|
||||
WithStartupTimeout(60*time.Second),
|
||||
),
|
||||
)
|
||||
s.Require().NoError(err)
|
||||
s.container = container
|
||||
}
|
||||
|
||||
// TearDownSuite runs once after all tests
|
||||
func (s *MySQLV2IntegrationTestSuite) TearDownSuite() {
|
||||
if s.container != nil {
|
||||
ctx := context.Background()
|
||||
err := s.container.Terminate(ctx)
|
||||
s.NoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
// SetupTest runs before each test
|
||||
func (s *MySQLV2IntegrationTestSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
|
||||
// Get connection string from container
|
||||
connStr, err := s.container.ConnectionString(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Create GORM store with MySQL
|
||||
cfg := Config{
|
||||
Driver: "mysql",
|
||||
DSN: connStr,
|
||||
MaxOpenConns: 10,
|
||||
MaxIdleConns: 5,
|
||||
ConnMaxLifetime: 3600 * time.Second,
|
||||
LogLevel: "silent",
|
||||
}
|
||||
|
||||
s.store, err = NewV2(cfg)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(s.store)
|
||||
}
|
||||
|
||||
// TearDownTest runs after each test
|
||||
func (s *MySQLV2IntegrationTestSuite) TearDownTest() {
|
||||
if s.store != nil {
|
||||
// Clean up all tables for next test
|
||||
tables := []string{
|
||||
"download_events",
|
||||
"download_stats_hourly",
|
||||
"download_stats_daily",
|
||||
"audit_log",
|
||||
"cve_bypasses",
|
||||
"scan_results",
|
||||
"package_vulnerabilities",
|
||||
"vulnerabilities",
|
||||
"package_metadata",
|
||||
"packages",
|
||||
"registries",
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
s.store.db.Exec("TRUNCATE TABLE " + table)
|
||||
}
|
||||
|
||||
// Re-seed default registries after truncate
|
||||
defaultRegistries := []RegistryModel{
|
||||
{Name: "npm", DisplayName: "NPM Registry", UpstreamURL: "https://registry.npmjs.org", Enabled: true, ScanByDefault: true},
|
||||
{Name: "pypi", DisplayName: "PyPI", UpstreamURL: "https://pypi.org", Enabled: true, ScanByDefault: true},
|
||||
{Name: "go", DisplayName: "Go Modules", UpstreamURL: "https://proxy.golang.org", Enabled: true, ScanByDefault: true},
|
||||
}
|
||||
for _, reg := range defaultRegistries {
|
||||
s.store.db.Create(®)
|
||||
}
|
||||
|
||||
// Rebuild registry cache
|
||||
s.store.rebuildRegistryCache()
|
||||
|
||||
s.store.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// TestMySQLV2IntegrationTestSuite runs the integration test suite with MySQL
|
||||
func TestMySQLV2IntegrationTestSuite(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration tests in short mode")
|
||||
}
|
||||
|
||||
suite.Run(t, new(MySQLV2IntegrationTestSuite))
|
||||
}
|
||||
|
||||
// Test_MySQLV2_SpecificFeatures tests MySQL-specific features
|
||||
func (s *MySQLV2IntegrationTestSuite) Test_MySQLV2_SpecificFeatures() {
|
||||
// Test that we're actually using MySQL
|
||||
var version string
|
||||
err := s.store.db.Raw("SELECT VERSION()").Scan(&version).Error
|
||||
s.NoError(err)
|
||||
s.Contains(version, "MySQL")
|
||||
}
|
||||
|
||||
// Test_MySQLV2_NoPartitioning tests that partition manager is nil for MySQL
|
||||
func (s *MySQLV2IntegrationTestSuite) Test_MySQLV2_NoPartitioning() {
|
||||
// MySQL doesn't use our partition manager (uses native partitioning differently)
|
||||
s.Nil(s.store.partitionManager)
|
||||
}
|
||||
|
||||
// Test_MySQLV2_HighConcurrency tests MySQL's concurrent write support
|
||||
func (s *MySQLV2IntegrationTestSuite) Test_MySQLV2_HighConcurrency() {
|
||||
pkg := &metadata.Package{
|
||||
Registry: "npm",
|
||||
Name: "concurrent-test",
|
||||
Version: "1.0.0",
|
||||
StorageKey: "npm/concurrent-test/1.0.0.tgz",
|
||||
Size: 1000,
|
||||
CachedAt: time.Now(),
|
||||
LastAccessed: time.Now(),
|
||||
}
|
||||
err := s.store.SavePackage(s.ctx, pkg)
|
||||
s.NoError(err)
|
||||
|
||||
// MySQL can handle concurrent writes (with InnoDB row-level locking)
|
||||
concurrency := 15
|
||||
done := make(chan bool, concurrency)
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func() {
|
||||
err := s.store.UpdateDownloadCount(s.ctx, "npm", "concurrent-test", "1.0.0")
|
||||
s.NoError(err)
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
for i := 0; i < concurrency; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Verify all updates succeeded
|
||||
retrieved, err := s.store.GetPackage(s.ctx, "npm", "concurrent-test", "1.0.0")
|
||||
s.NoError(err)
|
||||
s.Equal(int64(concurrency), retrieved.DownloadCount)
|
||||
}
|
||||
|
||||
// Test_MySQLV2_JSON tests MySQL JSON functionality
|
||||
func (s *MySQLV2IntegrationTestSuite) Test_MySQLV2_JSON() {
|
||||
metadata := map[string]interface{}{
|
||||
"author": "Test Author",
|
||||
"license": "MIT",
|
||||
"description": "Test package",
|
||||
"keywords": []interface{}{"test", "mysql", "json"},
|
||||
}
|
||||
|
||||
pkg := &metadata.Package{
|
||||
Registry: "npm",
|
||||
Name: "json-test",
|
||||
Version: "1.0.0",
|
||||
StorageKey: "npm/json-test/1.0.0.tgz",
|
||||
Size: 1000,
|
||||
CachedAt: time.Now(),
|
||||
LastAccessed: time.Now(),
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
err := s.store.SavePackage(s.ctx, pkg)
|
||||
s.NoError(err)
|
||||
|
||||
// Retrieve and verify JSON data
|
||||
retrieved, err := s.store.GetPackage(s.ctx, "npm", "json-test", "1.0.0")
|
||||
s.NoError(err)
|
||||
s.NotNil(retrieved.Metadata)
|
||||
s.Equal("MIT", retrieved.Metadata["license"])
|
||||
s.Equal("Test Author", retrieved.Metadata["author"])
|
||||
}
|
||||
|
||||
// Test_MySQLV2_TransactionRollback tests MySQL transaction rollback
|
||||
func (s *MySQLV2IntegrationTestSuite) Test_MySQLV2_TransactionRollback() {
|
||||
pkg := &metadata.Package{
|
||||
Registry: "npm",
|
||||
Name: "tx-test",
|
||||
Version: "1.0.0",
|
||||
StorageKey: "npm/tx-test/1.0.0.tgz",
|
||||
Size: 1000,
|
||||
CachedAt: time.Now(),
|
||||
LastAccessed: time.Now(),
|
||||
}
|
||||
err := s.store.SavePackage(s.ctx, pkg)
|
||||
s.NoError(err)
|
||||
|
||||
// Try to update with invalid data that should trigger rollback
|
||||
err = s.store.db.Transaction(func(tx *gorm.DB) error {
|
||||
// First update succeeds
|
||||
result := tx.Model(&PackageModel{}).
|
||||
Where("registry_id = ? AND name = ? AND version = ?",
|
||||
s.store.registryCache["npm"], "tx-test", "1.0.0").
|
||||
Update("access_count", gorm.Expr("access_count + ?", 1))
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
// Second operation fails (invalid foreign key)
|
||||
invalidModel := &PackageModel{
|
||||
RegistryID: 9999, // Non-existent registry
|
||||
Name: "invalid",
|
||||
Version: "1.0.0",
|
||||
StorageKey: "invalid",
|
||||
Size: 100,
|
||||
CachedAt: time.Now(),
|
||||
LastAccessed: time.Now(),
|
||||
}
|
||||
return tx.Create(invalidModel).Error
|
||||
})
|
||||
|
||||
// Transaction should fail
|
||||
s.Error(err)
|
||||
|
||||
// Verify first update was rolled back
|
||||
retrieved, err := s.store.GetPackage(s.ctx, "npm", "tx-test", "1.0.0")
|
||||
s.NoError(err)
|
||||
s.Equal(int64(0), retrieved.DownloadCount) // Should still be 0, not 1
|
||||
}
|
||||
|
||||
// Test_MySQLV2_CharacterSet tests MySQL UTF-8 support
|
||||
func (s *MySQLV2IntegrationTestSuite) Test_MySQLV2_CharacterSet() {
|
||||
// Test package with Unicode characters
|
||||
pkg := &metadata.Package{
|
||||
Registry: "npm",
|
||||
Name: "unicode-test-世界-🚀",
|
||||
Version: "1.0.0",
|
||||
StorageKey: "npm/unicode-test/1.0.0.tgz",
|
||||
Size: 1000,
|
||||
CachedAt: time.Now(),
|
||||
LastAccessed: time.Now(),
|
||||
Metadata: map[string]interface{}{
|
||||
"description": "Test with emoji 🎉 and Chinese 中文",
|
||||
},
|
||||
}
|
||||
|
||||
err := s.store.SavePackage(s.ctx, pkg)
|
||||
s.NoError(err)
|
||||
|
||||
// Retrieve and verify Unicode data preserved
|
||||
retrieved, err := s.store.GetPackage(s.ctx, "npm", "unicode-test-世界-🚀", "1.0.0")
|
||||
s.NoError(err)
|
||||
s.Equal("unicode-test-世界-🚀", retrieved.Name)
|
||||
s.Contains(retrieved.Metadata["description"], "🎉")
|
||||
s.Contains(retrieved.Metadata["description"], "中文")
|
||||
}
|
||||
@@ -0,0 +1,202 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package gormstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
)
|
||||
|
||||
// PostgresV2IntegrationTestSuite embeds the V2 test suite with PostgreSQL container
|
||||
type PostgresV2IntegrationTestSuite struct {
|
||||
GORMStoreV2TestSuite
|
||||
container *postgres.PostgresContainer
|
||||
}
|
||||
|
||||
// SetupSuite runs once before all tests
|
||||
func (s *PostgresV2IntegrationTestSuite) SetupSuite() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Start PostgreSQL container
|
||||
container, err := postgres.RunContainer(ctx,
|
||||
testcontainers.WithImage("postgres:16-alpine"),
|
||||
postgres.WithDatabase("testdb"),
|
||||
postgres.WithUsername("testuser"),
|
||||
postgres.WithPassword("testpass"),
|
||||
testcontainers.WithWaitStrategy(
|
||||
wait.ForLog("database system is ready to accept connections").
|
||||
WithOccurrence(2).
|
||||
WithStartupTimeout(60*time.Second),
|
||||
),
|
||||
)
|
||||
s.Require().NoError(err)
|
||||
s.container = container
|
||||
}
|
||||
|
||||
// TearDownSuite runs once after all tests
|
||||
func (s *PostgresV2IntegrationTestSuite) TearDownSuite() {
|
||||
if s.container != nil {
|
||||
ctx := context.Background()
|
||||
err := s.container.Terminate(ctx)
|
||||
s.NoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
// SetupTest runs before each test
|
||||
func (s *PostgresV2IntegrationTestSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
|
||||
// Get connection string from container
|
||||
connStr, err := s.container.ConnectionString(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Create GORM store with PostgreSQL
|
||||
cfg := Config{
|
||||
Driver: "postgres",
|
||||
DSN: connStr + "sslmode=disable",
|
||||
MaxOpenConns: 10,
|
||||
MaxIdleConns: 5,
|
||||
ConnMaxLifetime: 3600 * time.Second,
|
||||
LogLevel: "silent",
|
||||
}
|
||||
|
||||
s.store, err = NewV2(cfg)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(s.store)
|
||||
}
|
||||
|
||||
// TearDownTest runs after each test
|
||||
func (s *PostgresV2IntegrationTestSuite) TearDownTest() {
|
||||
if s.store != nil {
|
||||
// Clean up all tables for next test
|
||||
tables := []string{
|
||||
"download_events",
|
||||
"download_stats_hourly",
|
||||
"download_stats_daily",
|
||||
"audit_log",
|
||||
"cve_bypasses",
|
||||
"scan_results",
|
||||
"package_vulnerabilities",
|
||||
"vulnerabilities",
|
||||
"package_metadata",
|
||||
"packages",
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
s.store.db.Exec("TRUNCATE TABLE " + table + " CASCADE")
|
||||
}
|
||||
|
||||
s.store.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// TestPostgresV2IntegrationTestSuite runs the integration test suite with PostgreSQL
|
||||
func TestPostgresV2IntegrationTestSuite(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration tests in short mode")
|
||||
}
|
||||
|
||||
suite.Run(t, new(PostgresV2IntegrationTestSuite))
|
||||
}
|
||||
|
||||
// Test_PostgresV2_SpecificFeatures tests PostgreSQL-specific features
|
||||
func (s *PostgresV2IntegrationTestSuite) Test_PostgresV2_SpecificFeatures() {
|
||||
// Test that we're actually using PostgreSQL
|
||||
var version string
|
||||
err := s.store.db.Raw("SELECT version()").Scan(&version).Error
|
||||
s.NoError(err)
|
||||
s.Contains(version, "PostgreSQL")
|
||||
}
|
||||
|
||||
// Test_PostgresV2_Partitioning tests partition manager
|
||||
func (s *PostgresV2IntegrationTestSuite) Test_PostgresV2_Partitioning() {
|
||||
s.NotNil(s.store.partitionManager)
|
||||
|
||||
// Get partition info
|
||||
info, err := s.store.partitionManager.GetPartitionInfo()
|
||||
s.NoError(err)
|
||||
s.NotNil(info)
|
||||
|
||||
// Should have created partitions
|
||||
downloadPartitions := info["download_events_partitions"].(int64)
|
||||
s.Greater(downloadPartitions, int64(0))
|
||||
|
||||
auditPartitions := info["audit_log_partitions"].(int64)
|
||||
s.Greater(auditPartitions, int64(0))
|
||||
}
|
||||
|
||||
// Test_PostgresV2_HighConcurrency tests PostgreSQL's excellent concurrent write support
|
||||
func (s *PostgresV2IntegrationTestSuite) Test_PostgresV2_HighConcurrency() {
|
||||
pkg := &metadata.Package{
|
||||
Registry: "npm",
|
||||
Name: "concurrent-test",
|
||||
Version: "1.0.0",
|
||||
StorageKey: "npm/concurrent-test/1.0.0.tgz",
|
||||
Size: 1000,
|
||||
CachedAt: time.Now(),
|
||||
LastAccessed: time.Now(),
|
||||
}
|
||||
err := s.store.SavePackage(s.ctx, pkg)
|
||||
s.NoError(err)
|
||||
|
||||
// PostgreSQL can handle many concurrent writes
|
||||
concurrency := 20
|
||||
done := make(chan bool, concurrency)
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func() {
|
||||
err := s.store.UpdateDownloadCount(s.ctx, "npm", "concurrent-test", "1.0.0")
|
||||
s.NoError(err)
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
for i := 0; i < concurrency; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Verify all updates succeeded
|
||||
retrieved, err := s.store.GetPackage(s.ctx, "npm", "concurrent-test", "1.0.0")
|
||||
s.NoError(err)
|
||||
s.Equal(int64(concurrency), retrieved.DownloadCount)
|
||||
}
|
||||
|
||||
// Test_PostgresV2_JSONB tests PostgreSQL JSONB functionality
|
||||
func (s *PostgresV2IntegrationTestSuite) Test_PostgresV2_JSONB() {
|
||||
metadata := map[string]interface{}{
|
||||
"author": "Test Author",
|
||||
"license": "MIT",
|
||||
"description": "Test package",
|
||||
"keywords": []interface{}{"test", "postgres", "jsonb"},
|
||||
}
|
||||
|
||||
pkg := &metadata.Package{
|
||||
Registry: "npm",
|
||||
Name: "jsonb-test",
|
||||
Version: "1.0.0",
|
||||
StorageKey: "npm/jsonb-test/1.0.0.tgz",
|
||||
Size: 1000,
|
||||
CachedAt: time.Now(),
|
||||
LastAccessed: time.Now(),
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
err := s.store.SavePackage(s.ctx, pkg)
|
||||
s.NoError(err)
|
||||
|
||||
// Retrieve and verify JSONB data
|
||||
retrieved, err := s.store.GetPackage(s.ctx, "npm", "jsonb-test", "1.0.0")
|
||||
s.NoError(err)
|
||||
s.NotNil(retrieved.Metadata)
|
||||
s.Equal("MIT", retrieved.Metadata["license"])
|
||||
s.Equal("Test Author", retrieved.Metadata["author"])
|
||||
}
|
||||
@@ -0,0 +1,871 @@
|
||||
package gormstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// GORMStoreV2TestSuite is the test suite for V2 GORM implementation
|
||||
type GORMStoreV2TestSuite struct {
|
||||
suite.Suite
|
||||
db *gorm.DB
|
||||
store *GORMStoreV2
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// SetupSuite runs once before all tests
|
||||
func (s *GORMStoreV2TestSuite) SetupSuite() {
|
||||
// Use in-memory SQLite for fast tests
|
||||
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.db = db
|
||||
}
|
||||
|
||||
// SetupTest runs before each test
|
||||
func (s *GORMStoreV2TestSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
|
||||
// Create fresh store with V2 schema
|
||||
cfg := Config{
|
||||
Driver: "sqlite",
|
||||
DSN: "file::memory:?cache=shared",
|
||||
MaxOpenConns: 10,
|
||||
MaxIdleConns: 5,
|
||||
ConnMaxLifetime: 3600 * time.Second,
|
||||
LogLevel: "silent",
|
||||
}
|
||||
|
||||
store, err := NewV2(cfg)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(store)
|
||||
|
||||
s.store = store
|
||||
}
|
||||
|
||||
// TearDownTest runs after each test
|
||||
func (s *GORMStoreV2TestSuite) TearDownTest() {
|
||||
if s.store != nil {
|
||||
// Clean up all tables for next test
|
||||
tables := []string{
|
||||
"audit_log",
|
||||
"download_stats_daily",
|
||||
"download_stats_hourly",
|
||||
"download_events",
|
||||
"cve_bypasses",
|
||||
"scan_results",
|
||||
"package_vulnerabilities",
|
||||
"vulnerabilities",
|
||||
"package_metadata",
|
||||
"packages",
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
s.store.db.Exec(fmt.Sprintf("DELETE FROM %s", table))
|
||||
}
|
||||
|
||||
s.store.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// TestGORMStoreV2TestSuite runs the test suite
|
||||
func TestGORMStoreV2TestSuite(t *testing.T) {
|
||||
suite.Run(t, new(GORMStoreV2TestSuite))
|
||||
}
|
||||
|
||||
// Test_V2_SavePackage_Success tests saving a package
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_SavePackage_Success() {
|
||||
pkg := &metadata.Package{
|
||||
Registry: "npm",
|
||||
Name: "test-package",
|
||||
Version: "1.0.0",
|
||||
StorageKey: "npm/test-package/1.0.0.tgz",
|
||||
Size: 12345,
|
||||
ChecksumMD5: "abc123",
|
||||
ChecksumSHA256: "def456",
|
||||
UpstreamURL: "https://registry.npmjs.org/test-package",
|
||||
CachedAt: time.Now(),
|
||||
LastAccessed: time.Now(),
|
||||
}
|
||||
|
||||
err := s.store.SavePackage(s.ctx, pkg)
|
||||
s.NoError(err)
|
||||
|
||||
// Verify package was saved
|
||||
retrieved, err := s.store.GetPackage(s.ctx, "npm", "test-package", "1.0.0")
|
||||
s.NoError(err)
|
||||
s.NotNil(retrieved)
|
||||
s.Equal("npm", retrieved.Registry)
|
||||
s.Equal("test-package", retrieved.Name)
|
||||
s.Equal("1.0.0", retrieved.Version)
|
||||
s.Equal(int64(12345), retrieved.Size)
|
||||
}
|
||||
|
||||
// Test_V2_SavePackage_WithMetadata tests saving package with metadata
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_SavePackage_WithMetadata() {
|
||||
metadataMap := map[string]string{
|
||||
"author": "Test Author",
|
||||
"license": "MIT",
|
||||
"homepage": "https://example.com",
|
||||
"description": "Test package description",
|
||||
}
|
||||
|
||||
pkg := &metadata.Package{
|
||||
Registry: "npm",
|
||||
Name: "meta-package",
|
||||
Version: "2.0.0",
|
||||
StorageKey: "npm/meta-package/2.0.0.tgz",
|
||||
Size: 5000,
|
||||
CachedAt: time.Now(),
|
||||
LastAccessed: time.Now(),
|
||||
Metadata: metadataMap,
|
||||
}
|
||||
|
||||
err := s.store.SavePackage(s.ctx, pkg)
|
||||
s.NoError(err)
|
||||
|
||||
// Verify metadata was saved in separate table
|
||||
var pkgMetadata PackageMetadataModel
|
||||
err = s.store.db.Where("package_id = (SELECT id FROM packages WHERE name = ?)", "meta-package").
|
||||
First(&pkgMetadata).Error
|
||||
s.NoError(err)
|
||||
s.Equal("Test Author", pkgMetadata.Author)
|
||||
s.Equal("MIT", pkgMetadata.License)
|
||||
s.Equal("https://example.com", pkgMetadata.Homepage)
|
||||
}
|
||||
|
||||
// Test_V2_SavePackage_Upsert tests update on conflict
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_SavePackage_Upsert() {
|
||||
// Save initial package
|
||||
pkg := &metadata.Package{
|
||||
Registry: "npm",
|
||||
Name: "upsert-test",
|
||||
Version: "1.0.0",
|
||||
StorageKey: "npm/upsert-test/1.0.0.tgz",
|
||||
Size: 1000,
|
||||
CachedAt: time.Now(),
|
||||
LastAccessed: time.Now(),
|
||||
}
|
||||
err := s.store.SavePackage(s.ctx, pkg)
|
||||
s.NoError(err)
|
||||
|
||||
// Update same package
|
||||
pkg.Size = 2000
|
||||
pkg.ChecksumMD5 = "updated"
|
||||
err = s.store.SavePackage(s.ctx, pkg)
|
||||
s.NoError(err)
|
||||
|
||||
// Verify updated
|
||||
retrieved, err := s.store.GetPackage(s.ctx, "npm", "upsert-test", "1.0.0")
|
||||
s.NoError(err)
|
||||
s.Equal(int64(2000), retrieved.Size)
|
||||
s.Equal("updated", retrieved.ChecksumMD5)
|
||||
}
|
||||
|
||||
// Test_V2_GetPackage_NotFound tests getting non-existent package
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_GetPackage_NotFound() {
|
||||
_, err := s.store.GetPackage(s.ctx, "npm", "nonexistent", "1.0.0")
|
||||
s.Error(err)
|
||||
s.Contains(err.Error(), "not found")
|
||||
}
|
||||
|
||||
// Test_V2_DeletePackage_Success tests soft delete
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_DeletePackage_Success() {
|
||||
// Save package
|
||||
pkg := &metadata.Package{
|
||||
Registry: "npm",
|
||||
Name: "delete-test",
|
||||
Version: "1.0.0",
|
||||
StorageKey: "npm/delete-test/1.0.0.tgz",
|
||||
Size: 1000,
|
||||
CachedAt: time.Now(),
|
||||
LastAccessed: time.Now(),
|
||||
}
|
||||
err := s.store.SavePackage(s.ctx, pkg)
|
||||
s.NoError(err)
|
||||
|
||||
// Delete package (soft delete)
|
||||
err = s.store.DeletePackage(s.ctx, "npm", "delete-test", "1.0.0")
|
||||
s.NoError(err)
|
||||
|
||||
// Verify deleted (should not be found)
|
||||
_, err = s.store.GetPackage(s.ctx, "npm", "delete-test", "1.0.0")
|
||||
s.Error(err)
|
||||
|
||||
// Verify soft delete (deleted_at set)
|
||||
var count int64
|
||||
s.store.db.Unscoped().Model(&PackageModel{}).
|
||||
Where("name = ?", "delete-test").
|
||||
Count(&count)
|
||||
s.Equal(int64(1), count) // Still in DB, just soft deleted
|
||||
}
|
||||
|
||||
// Test_V2_ListPackages_All tests listing all packages
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_ListPackages_All() {
|
||||
// Create multiple packages
|
||||
for i := 0; i < 5; i++ {
|
||||
pkg := &metadata.Package{
|
||||
Registry: "npm",
|
||||
Name: fmt.Sprintf("package-%d", i),
|
||||
Version: "1.0.0",
|
||||
StorageKey: fmt.Sprintf("npm/package-%d/1.0.0.tgz", i),
|
||||
Size: int64(i * 1000),
|
||||
CachedAt: time.Now(),
|
||||
LastAccessed: time.Now(),
|
||||
}
|
||||
err := s.store.SavePackage(s.ctx, pkg)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// List all packages
|
||||
packages, err := s.store.ListPackages(s.ctx, &metadata.ListOptions{})
|
||||
s.NoError(err)
|
||||
s.Len(packages, 5)
|
||||
}
|
||||
|
||||
// Test_V2_ListPackages_FilterByRegistry tests filtering by registry
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_ListPackages_FilterByRegistry() {
|
||||
// Create packages in different registries
|
||||
registries := []string{"npm", "pypi", "go"}
|
||||
for _, reg := range registries {
|
||||
pkg := &metadata.Package{
|
||||
Registry: reg,
|
||||
Name: "test-package",
|
||||
Version: "1.0.0",
|
||||
StorageKey: fmt.Sprintf("%s/test-package/1.0.0", reg),
|
||||
Size: 1000,
|
||||
CachedAt: time.Now(),
|
||||
LastAccessed: time.Now(),
|
||||
}
|
||||
err := s.store.SavePackage(s.ctx, pkg)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// Filter by npm registry
|
||||
packages, err := s.store.ListPackages(s.ctx, &metadata.ListOptions{
|
||||
Registry: "npm",
|
||||
})
|
||||
s.NoError(err)
|
||||
s.Len(packages, 1)
|
||||
s.Equal("npm", packages[0].Registry)
|
||||
}
|
||||
|
||||
// Test_V2_ListPackages_Pagination tests pagination
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_ListPackages_Pagination() {
|
||||
// Create 10 packages
|
||||
for i := 0; i < 10; i++ {
|
||||
pkg := &metadata.Package{
|
||||
Registry: "npm",
|
||||
Name: fmt.Sprintf("package-%d", i),
|
||||
Version: "1.0.0",
|
||||
StorageKey: fmt.Sprintf("npm/package-%d/1.0.0.tgz", i),
|
||||
Size: int64(i * 1000),
|
||||
CachedAt: time.Now(),
|
||||
LastAccessed: time.Now(),
|
||||
}
|
||||
err := s.store.SavePackage(s.ctx, pkg)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// Get first page (5 items)
|
||||
page1, err := s.store.ListPackages(s.ctx, &metadata.ListOptions{
|
||||
Limit: 5,
|
||||
Offset: 0,
|
||||
})
|
||||
s.NoError(err)
|
||||
s.Len(page1, 5)
|
||||
|
||||
// Get second page (5 items)
|
||||
page2, err := s.store.ListPackages(s.ctx, &metadata.ListOptions{
|
||||
Limit: 5,
|
||||
Offset: 5,
|
||||
})
|
||||
s.NoError(err)
|
||||
s.Len(page2, 5)
|
||||
|
||||
// Verify different packages
|
||||
s.NotEqual(page1[0].Name, page2[0].Name)
|
||||
}
|
||||
|
||||
// Test_V2_UpdateDownloadCount_Success tests incrementing download count
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_UpdateDownloadCount_Success() {
|
||||
// Create package
|
||||
pkg := &metadata.Package{
|
||||
Registry: "npm",
|
||||
Name: "download-test",
|
||||
Version: "1.0.0",
|
||||
StorageKey: "npm/download-test/1.0.0.tgz",
|
||||
Size: 1000,
|
||||
CachedAt: time.Now(),
|
||||
LastAccessed: time.Now(),
|
||||
}
|
||||
err := s.store.SavePackage(s.ctx, pkg)
|
||||
s.NoError(err)
|
||||
|
||||
// Update download count
|
||||
err = s.store.UpdateDownloadCount(s.ctx, "npm", "download-test", "1.0.0")
|
||||
s.NoError(err)
|
||||
|
||||
// Verify count incremented
|
||||
retrieved, err := s.store.GetPackage(s.ctx, "npm", "download-test", "1.0.0")
|
||||
s.NoError(err)
|
||||
s.Equal(int64(1), retrieved.DownloadCount)
|
||||
|
||||
// Update again
|
||||
err = s.store.UpdateDownloadCount(s.ctx, "npm", "download-test", "1.0.0")
|
||||
s.NoError(err)
|
||||
|
||||
retrieved, err = s.store.GetPackage(s.ctx, "npm", "download-test", "1.0.0")
|
||||
s.NoError(err)
|
||||
s.Equal(int64(2), retrieved.DownloadCount)
|
||||
|
||||
// Verify download event was recorded
|
||||
var eventCount int64
|
||||
s.store.db.Model(&DownloadEventModel{}).Count(&eventCount)
|
||||
s.Equal(int64(2), eventCount)
|
||||
}
|
||||
|
||||
// Test_V2_Count tests counting packages
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_Count() {
|
||||
// Initially zero
|
||||
count, err := s.store.Count(s.ctx)
|
||||
s.NoError(err)
|
||||
s.Equal(0, count)
|
||||
|
||||
// Create 3 packages
|
||||
for i := 0; i < 3; i++ {
|
||||
pkg := &metadata.Package{
|
||||
Registry: "npm",
|
||||
Name: fmt.Sprintf("count-test-%d", i),
|
||||
Version: "1.0.0",
|
||||
StorageKey: fmt.Sprintf("npm/count-test-%d/1.0.0.tgz", i),
|
||||
Size: 1000,
|
||||
CachedAt: time.Now(),
|
||||
LastAccessed: time.Now(),
|
||||
}
|
||||
err := s.store.SavePackage(s.ctx, pkg)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
count, err = s.store.Count(s.ctx)
|
||||
s.NoError(err)
|
||||
s.Equal(3, count)
|
||||
}
|
||||
|
||||
// Test_V2_GetStats tests aggregated statistics
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_GetStats() {
|
||||
// Create packages in different registries
|
||||
packages := []*metadata.Package{
|
||||
{Registry: "npm", Name: "pkg1", Version: "1.0.0", StorageKey: "npm/pkg1/1.0.0.tgz", Size: 1000, CachedAt: time.Now(), LastAccessed: time.Now()},
|
||||
{Registry: "npm", Name: "pkg2", Version: "1.0.0", StorageKey: "npm/pkg2/1.0.0.tgz", Size: 2000, CachedAt: time.Now(), LastAccessed: time.Now()},
|
||||
{Registry: "pypi", Name: "pkg3", Version: "1.0.0", StorageKey: "pypi/pkg3/1.0.0.tar.gz", Size: 3000, CachedAt: time.Now(), LastAccessed: time.Now()},
|
||||
}
|
||||
|
||||
for _, pkg := range packages {
|
||||
err := s.store.SavePackage(s.ctx, pkg)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// Update download counts
|
||||
s.store.UpdateDownloadCount(s.ctx, "npm", "pkg1", "1.0.0")
|
||||
s.store.UpdateDownloadCount(s.ctx, "npm", "pkg1", "1.0.0")
|
||||
s.store.UpdateDownloadCount(s.ctx, "npm", "pkg2", "1.0.0")
|
||||
|
||||
// Get stats for all registries
|
||||
statsAll, err := s.store.GetStats(s.ctx, "")
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), statsAll.TotalPackages)
|
||||
s.Equal(int64(6000), statsAll.TotalSize)
|
||||
s.Equal(int64(3), statsAll.TotalDownloads)
|
||||
|
||||
// Get stats for npm registry
|
||||
statsNpm, err := s.store.GetStats(s.ctx, "npm")
|
||||
s.NoError(err)
|
||||
s.Equal("npm", statsNpm.Registry)
|
||||
s.Equal(int64(2), statsNpm.TotalPackages)
|
||||
s.Equal(int64(3000), statsNpm.TotalSize)
|
||||
s.Equal(int64(3), statsNpm.TotalDownloads)
|
||||
|
||||
// Get stats for pypi registry
|
||||
statsPypi, err := s.store.GetStats(s.ctx, "pypi")
|
||||
s.NoError(err)
|
||||
s.Equal("pypi", statsPypi.Registry)
|
||||
s.Equal(int64(1), statsPypi.TotalPackages)
|
||||
s.Equal(int64(3000), statsPypi.TotalSize)
|
||||
}
|
||||
|
||||
// Test_V2_Health tests database health check
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_Health() {
|
||||
err := s.store.Health(s.ctx)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// Test_V2_RegistryCache tests registry caching
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_RegistryCache() {
|
||||
// Default registries should be cached
|
||||
s.Contains(s.store.registryCache, "npm")
|
||||
s.Contains(s.store.registryCache, "pypi")
|
||||
s.Contains(s.store.registryCache, "go")
|
||||
|
||||
// Get registry ID from cache
|
||||
npmID, err := s.store.getRegistryID("npm")
|
||||
s.NoError(err)
|
||||
s.Greater(npmID, int32(0))
|
||||
|
||||
// Second call should use cache (no DB query)
|
||||
npmID2, err := s.store.getRegistryID("npm")
|
||||
s.NoError(err)
|
||||
s.Equal(npmID, npmID2)
|
||||
|
||||
// Non-existent registry
|
||||
_, err = s.store.getRegistryID("nonexistent")
|
||||
s.Error(err)
|
||||
s.Contains(err.Error(), "not found")
|
||||
}
|
||||
|
||||
// Test_V2_SoftDelete tests soft delete behavior
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_SoftDelete() {
|
||||
// Create package
|
||||
pkg := &metadata.Package{
|
||||
Registry: "npm",
|
||||
Name: "soft-delete",
|
||||
Version: "1.0.0",
|
||||
StorageKey: "npm/soft-delete/1.0.0.tgz",
|
||||
Size: 1000,
|
||||
CachedAt: time.Now(),
|
||||
LastAccessed: time.Now(),
|
||||
}
|
||||
err := s.store.SavePackage(s.ctx, pkg)
|
||||
s.NoError(err)
|
||||
|
||||
// Delete
|
||||
err = s.store.DeletePackage(s.ctx, "npm", "soft-delete", "1.0.0")
|
||||
s.NoError(err)
|
||||
|
||||
// Count should not include deleted
|
||||
count, err := s.store.Count(s.ctx)
|
||||
s.NoError(err)
|
||||
s.Equal(0, count)
|
||||
|
||||
// But record still exists with deleted_at set
|
||||
var pkgModel PackageModel
|
||||
err = s.store.db.Unscoped().Where("name = ?", "soft-delete").First(&pkgModel).Error
|
||||
s.NoError(err)
|
||||
s.NotNil(pkgModel.DeletedAt)
|
||||
}
|
||||
|
||||
// Test_V2_AggregationWorker tests that aggregation worker is initialized
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_AggregationWorker() {
|
||||
s.NotNil(s.store.aggregationWorker)
|
||||
}
|
||||
|
||||
// Test_V2_ConcurrentUpdates tests concurrent download count updates
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_ConcurrentUpdates() {
|
||||
// Create package
|
||||
pkg := &metadata.Package{
|
||||
Registry: "npm",
|
||||
Name: "concurrent-test",
|
||||
Version: "1.0.0",
|
||||
StorageKey: "npm/concurrent-test/1.0.0.tgz",
|
||||
Size: 1000,
|
||||
CachedAt: time.Now(),
|
||||
LastAccessed: time.Now(),
|
||||
}
|
||||
err := s.store.SavePackage(s.ctx, pkg)
|
||||
s.NoError(err)
|
||||
|
||||
// SQLite: Sequential updates only (write lock prevents concurrent writes)
|
||||
updateCount := 5
|
||||
for i := 0; i < updateCount; i++ {
|
||||
err := s.store.UpdateDownloadCount(s.ctx, "npm", "concurrent-test", "1.0.0")
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// Verify all updates succeeded
|
||||
retrieved, err := s.store.GetPackage(s.ctx, "npm", "concurrent-test", "1.0.0")
|
||||
s.NoError(err)
|
||||
s.Equal(int64(updateCount), retrieved.DownloadCount)
|
||||
}
|
||||
|
||||
// Test_V2_SaveScanResult tests saving a scan result
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_SaveScanResult() {
|
||||
// Create a package first
|
||||
pkg := &metadata.Package{
|
||||
Registry: "npm",
|
||||
Name: "test-package",
|
||||
Version: "1.0.0",
|
||||
StorageKey: "/cache/npm/test-package-1.0.0.tgz",
|
||||
Size: 1024,
|
||||
UpstreamURL: "https://registry.npmjs.org/test-package",
|
||||
CachedAt: time.Now(),
|
||||
LastAccessed: time.Now(),
|
||||
}
|
||||
err := s.store.SavePackage(s.ctx, pkg)
|
||||
s.NoError(err)
|
||||
|
||||
// Create and save a scan result
|
||||
scanResult := &metadata.ScanResult{
|
||||
Registry: "npm",
|
||||
PackageName: "test-package",
|
||||
PackageVersion: "1.0.0",
|
||||
Scanner: "trivy",
|
||||
Status: metadata.ScanStatusVulnerable,
|
||||
ScannedAt: time.Now(),
|
||||
Vulnerabilities: []metadata.Vulnerability{
|
||||
{
|
||||
ID: "CVE-2024-0001",
|
||||
Severity: "HIGH",
|
||||
Title: "Test vulnerability",
|
||||
Description: "Test description",
|
||||
References: []string{"https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2024-0001"},
|
||||
},
|
||||
{
|
||||
ID: "CVE-2024-0002",
|
||||
Severity: "CRITICAL",
|
||||
Title: "Critical vulnerability",
|
||||
Description: "Critical test description",
|
||||
References: []string{"https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2024-0002"},
|
||||
},
|
||||
},
|
||||
VulnerabilityCount: 2,
|
||||
Details: map[string]interface{}{
|
||||
"scan_duration": 42,
|
||||
"scanner_version": "1.0.0",
|
||||
},
|
||||
}
|
||||
|
||||
err = s.store.SaveScanResult(s.ctx, scanResult)
|
||||
s.NoError(err)
|
||||
|
||||
// Verify the scan result was saved and package was updated
|
||||
retrievedPkg, err := s.store.GetPackage(s.ctx, "npm", "test-package", "1.0.0")
|
||||
s.NoError(err)
|
||||
s.True(retrievedPkg.SecurityScanned)
|
||||
}
|
||||
|
||||
// Test_V2_GetScanResult tests retrieving a scan result
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_GetScanResult() {
|
||||
// Create a package
|
||||
pkg := &metadata.Package{
|
||||
Registry: "npm",
|
||||
Name: "scan-test",
|
||||
Version: "2.0.0",
|
||||
StorageKey: "/cache/npm/scan-test-2.0.0.tgz",
|
||||
Size: 2048,
|
||||
UpstreamURL: "https://registry.npmjs.org/scan-test",
|
||||
CachedAt: time.Now(),
|
||||
LastAccessed: time.Now(),
|
||||
}
|
||||
err := s.store.SavePackage(s.ctx, pkg)
|
||||
s.NoError(err)
|
||||
|
||||
// Save a scan result with vulnerabilities
|
||||
scanResult := &metadata.ScanResult{
|
||||
Registry: "npm",
|
||||
PackageName: "scan-test",
|
||||
PackageVersion: "2.0.0",
|
||||
Scanner: "grype",
|
||||
Status: metadata.ScanStatusVulnerable,
|
||||
ScannedAt: time.Now(),
|
||||
Vulnerabilities: []metadata.Vulnerability{
|
||||
{
|
||||
ID: "CVE-2024-1234",
|
||||
Severity: "HIGH",
|
||||
Title: "Test High Severity",
|
||||
Description: "High severity test",
|
||||
References: []string{"https://example.com/cve-2024-1234"},
|
||||
FixedIn: "2.1.0",
|
||||
},
|
||||
{
|
||||
ID: "CVE-2024-5678",
|
||||
Severity: "MODERATE",
|
||||
Title: "Test Moderate Severity",
|
||||
Description: "Moderate severity test",
|
||||
References: []string{"https://example.com/cve-2024-5678"},
|
||||
},
|
||||
},
|
||||
VulnerabilityCount: 2,
|
||||
}
|
||||
err = s.store.SaveScanResult(s.ctx, scanResult)
|
||||
s.NoError(err)
|
||||
|
||||
// Retrieve the scan result
|
||||
retrieved, err := s.store.GetScanResult(s.ctx, "npm", "scan-test", "2.0.0")
|
||||
s.NoError(err)
|
||||
s.NotNil(retrieved)
|
||||
s.Equal("grype", retrieved.Scanner)
|
||||
s.Equal(metadata.ScanStatusVulnerable, retrieved.Status)
|
||||
s.Equal(2, retrieved.VulnerabilityCount)
|
||||
s.Len(retrieved.Vulnerabilities, 2)
|
||||
|
||||
// Verify vulnerability details are retrieved correctly
|
||||
s.Equal("CVE-2024-1234", retrieved.Vulnerabilities[0].ID)
|
||||
s.Equal("HIGH", retrieved.Vulnerabilities[0].Severity)
|
||||
s.Equal("Test High Severity", retrieved.Vulnerabilities[0].Title)
|
||||
s.Equal("2.1.0", retrieved.Vulnerabilities[0].FixedIn)
|
||||
s.Len(retrieved.Vulnerabilities[0].References, 1)
|
||||
}
|
||||
|
||||
// Test_V2_GetScanResult_NotFound tests retrieving a non-existent scan result
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_GetScanResult_NotFound() {
|
||||
_, err := s.store.GetScanResult(s.ctx, "npm", "nonexistent", "1.0.0")
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
// Test_V2_SaveCVEBypass tests saving a CVE bypass
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_SaveCVEBypass() {
|
||||
bypass := &metadata.CVEBypass{
|
||||
Type: metadata.BypassTypeCVE,
|
||||
Target: "CVE-2024-0001",
|
||||
Reason: "False positive - not applicable to our use case",
|
||||
CreatedBy: "admin@example.com",
|
||||
ExpiresAt: time.Now().Add(30 * 24 * time.Hour), // 30 days
|
||||
NotifyOnExpiry: true,
|
||||
Active: true,
|
||||
}
|
||||
|
||||
err := s.store.SaveCVEBypass(s.ctx, bypass)
|
||||
s.NoError(err)
|
||||
s.NotEmpty(bypass.ID)
|
||||
s.NotZero(bypass.CreatedAt)
|
||||
}
|
||||
|
||||
// Test_V2_SaveCVEBypass_Update tests updating an existing CVE bypass
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_SaveCVEBypass_Update() {
|
||||
// Create initial bypass
|
||||
bypass := &metadata.CVEBypass{
|
||||
Type: metadata.BypassTypeCVE,
|
||||
Target: "CVE-2024-0002",
|
||||
Reason: "Initial reason",
|
||||
CreatedBy: "admin@example.com",
|
||||
ExpiresAt: time.Now().Add(30 * 24 * time.Hour),
|
||||
NotifyOnExpiry: false,
|
||||
Active: true,
|
||||
}
|
||||
err := s.store.SaveCVEBypass(s.ctx, bypass)
|
||||
s.NoError(err)
|
||||
s.NotEmpty(bypass.ID)
|
||||
|
||||
// Update the bypass
|
||||
bypass.Reason = "Updated reason"
|
||||
bypass.NotifyOnExpiry = true
|
||||
err = s.store.SaveCVEBypass(s.ctx, bypass)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// Test_V2_GetActiveCVEBypasses tests retrieving active CVE bypasses
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_GetActiveCVEBypasses() {
|
||||
// Create active bypass with unique target
|
||||
uniqueTarget := fmt.Sprintf("CVE-2024-TEST-%d", time.Now().UnixNano())
|
||||
activeBypass := &metadata.CVEBypass{
|
||||
Type: metadata.BypassTypeCVE,
|
||||
Target: uniqueTarget,
|
||||
Reason: "Active bypass",
|
||||
CreatedBy: "admin@example.com",
|
||||
ExpiresAt: time.Now().Add(30 * 24 * time.Hour),
|
||||
Active: true,
|
||||
}
|
||||
err := s.store.SaveCVEBypass(s.ctx, activeBypass)
|
||||
s.NoError(err)
|
||||
|
||||
// Create expired bypass
|
||||
expiredBypass := &metadata.CVEBypass{
|
||||
Type: metadata.BypassTypeCVE,
|
||||
Target: "CVE-2024-0004",
|
||||
Reason: "Expired bypass",
|
||||
CreatedBy: "admin@example.com",
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour), // Expired yesterday
|
||||
Active: true,
|
||||
}
|
||||
err = s.store.SaveCVEBypass(s.ctx, expiredBypass)
|
||||
s.NoError(err)
|
||||
|
||||
// Create inactive bypass
|
||||
inactiveBypass := &metadata.CVEBypass{
|
||||
Type: metadata.BypassTypeCVE,
|
||||
Target: "CVE-2024-0005",
|
||||
Reason: "Inactive bypass",
|
||||
CreatedBy: "admin@example.com",
|
||||
ExpiresAt: time.Now().Add(30 * 24 * time.Hour),
|
||||
Active: false,
|
||||
}
|
||||
err = s.store.SaveCVEBypass(s.ctx, inactiveBypass)
|
||||
s.NoError(err)
|
||||
|
||||
// Retrieve active bypasses
|
||||
bypasses, err := s.store.GetActiveCVEBypasses(s.ctx)
|
||||
s.NoError(err)
|
||||
|
||||
// Should contain our active bypass, but may contain others from parallel tests
|
||||
found := false
|
||||
for _, b := range bypasses {
|
||||
if b.Target == uniqueTarget {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
// All bypasses should be active and non-expired
|
||||
s.True(b.Active)
|
||||
s.True(b.ExpiresAt.After(time.Now()))
|
||||
}
|
||||
s.True(found, "Should find our unique active bypass")
|
||||
}
|
||||
|
||||
// Test_V2_ListCVEBypasses tests listing CVE bypasses with filters
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_ListCVEBypasses() {
|
||||
// Create multiple bypasses with unique targets
|
||||
nano := time.Now().UnixNano()
|
||||
bypasses := []*metadata.CVEBypass{
|
||||
{
|
||||
Type: metadata.BypassTypeCVE,
|
||||
Target: fmt.Sprintf("CVE-2024-LIST-%d-1", nano),
|
||||
Reason: "Test 1",
|
||||
CreatedBy: "admin@example.com",
|
||||
ExpiresAt: time.Now().Add(30 * 24 * time.Hour),
|
||||
Active: true,
|
||||
},
|
||||
{
|
||||
Type: metadata.BypassTypePackage,
|
||||
Target: fmt.Sprintf("npm/vulnerable-package@%d", nano),
|
||||
Reason: "Test 2",
|
||||
CreatedBy: "admin@example.com",
|
||||
ExpiresAt: time.Now().Add(15 * 24 * time.Hour),
|
||||
Active: true,
|
||||
},
|
||||
{
|
||||
Type: metadata.BypassTypeCVE,
|
||||
Target: fmt.Sprintf("CVE-2024-LIST-%d-2", nano),
|
||||
Reason: "Test 3",
|
||||
CreatedBy: "admin@example.com",
|
||||
ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired
|
||||
Active: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, b := range bypasses {
|
||||
err := s.store.SaveCVEBypass(s.ctx, b)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// List only CVE type
|
||||
opts := &metadata.BypassListOptions{
|
||||
Type: metadata.BypassTypeCVE,
|
||||
}
|
||||
cveOnly, err := s.store.ListCVEBypasses(s.ctx, opts)
|
||||
s.NoError(err)
|
||||
for _, b := range cveOnly {
|
||||
s.Equal(metadata.BypassTypeCVE, b.Type)
|
||||
}
|
||||
|
||||
// List only non-expired
|
||||
opts = &metadata.BypassListOptions{
|
||||
IncludeExpired: false,
|
||||
}
|
||||
nonExpired, err := s.store.ListCVEBypasses(s.ctx, opts)
|
||||
s.NoError(err)
|
||||
for _, b := range nonExpired {
|
||||
s.True(b.ExpiresAt.After(time.Now()))
|
||||
}
|
||||
|
||||
// Test pagination
|
||||
opts = &metadata.BypassListOptions{
|
||||
Limit: 1,
|
||||
Offset: 0,
|
||||
}
|
||||
page1, err := s.store.ListCVEBypasses(s.ctx, opts)
|
||||
s.NoError(err)
|
||||
s.LessOrEqual(len(page1), 1) // Should be at most 1
|
||||
}
|
||||
|
||||
// Test_V2_DeleteCVEBypass tests deleting a CVE bypass
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_DeleteCVEBypass() {
|
||||
// Create a bypass
|
||||
bypass := &metadata.CVEBypass{
|
||||
Type: metadata.BypassTypeCVE,
|
||||
Target: "CVE-2024-0008",
|
||||
Reason: "To be deleted",
|
||||
CreatedBy: "admin@example.com",
|
||||
ExpiresAt: time.Now().Add(30 * 24 * time.Hour),
|
||||
Active: true,
|
||||
}
|
||||
err := s.store.SaveCVEBypass(s.ctx, bypass)
|
||||
s.NoError(err)
|
||||
s.NotEmpty(bypass.ID)
|
||||
|
||||
// Delete the bypass
|
||||
err = s.store.DeleteCVEBypass(s.ctx, bypass.ID)
|
||||
s.NoError(err)
|
||||
|
||||
// Verify it's no longer in active bypasses
|
||||
active, err := s.store.GetActiveCVEBypasses(s.ctx)
|
||||
s.NoError(err)
|
||||
for _, b := range active {
|
||||
s.NotEqual(bypass.ID, b.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// Test_V2_DeleteCVEBypass_NotFound tests deleting a non-existent bypass
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_DeleteCVEBypass_NotFound() {
|
||||
err := s.store.DeleteCVEBypass(s.ctx, "99999999")
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
// Test_V2_DeleteCVEBypass_InvalidID tests deleting with invalid ID
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_DeleteCVEBypass_InvalidID() {
|
||||
err := s.store.DeleteCVEBypass(s.ctx, "invalid-id")
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
// Test_V2_CleanupExpiredBypasses tests cleaning up expired bypasses
|
||||
func (s *GORMStoreV2TestSuite) Test_V2_CleanupExpiredBypasses() {
|
||||
// Create expired bypasses
|
||||
for i := 0; i < 3; i++ {
|
||||
bypass := &metadata.CVEBypass{
|
||||
Type: metadata.BypassTypeCVE,
|
||||
Target: fmt.Sprintf("CVE-2024-00%d", 10+i),
|
||||
Reason: "Expired bypass",
|
||||
CreatedBy: "admin@example.com",
|
||||
ExpiresAt: time.Now().Add(-24 * time.Hour), // Expired
|
||||
Active: true,
|
||||
}
|
||||
err := s.store.SaveCVEBypass(s.ctx, bypass)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// Create active bypass (should not be deleted)
|
||||
activeBypass := &metadata.CVEBypass{
|
||||
Type: metadata.BypassTypeCVE,
|
||||
Target: "CVE-2024-0999",
|
||||
Reason: "Active bypass",
|
||||
CreatedBy: "admin@example.com",
|
||||
ExpiresAt: time.Now().Add(30 * 24 * time.Hour),
|
||||
Active: true,
|
||||
}
|
||||
err := s.store.SaveCVEBypass(s.ctx, activeBypass)
|
||||
s.NoError(err)
|
||||
|
||||
// Cleanup expired bypasses
|
||||
count, err := s.store.CleanupExpiredBypasses(s.ctx)
|
||||
s.NoError(err)
|
||||
s.GreaterOrEqual(count, 3) // At least the 3 we just created
|
||||
|
||||
// Verify active bypass is still there
|
||||
active, err := s.store.GetActiveCVEBypasses(s.ctx)
|
||||
s.NoError(err)
|
||||
found := false
|
||||
for _, b := range active {
|
||||
if b.Target == "CVE-2024-0999" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
s.True(found, "Active bypass should still exist")
|
||||
}
|
||||
@@ -0,0 +1,228 @@
|
||||
package gormstore
|
||||
|
||||
import (
|
||||
"github.com/go-gormigrate/gormigrate/v2"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GetMigrations returns all database migrations for gormigrate
|
||||
func GetMigrations() []*gormigrate.Migration {
|
||||
return []*gormigrate.Migration{
|
||||
{
|
||||
ID: "202601030001",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
// Migration: Create V2 schema
|
||||
return migrateToV2(tx)
|
||||
},
|
||||
Rollback: func(tx *gorm.DB) error {
|
||||
// Rollback: Drop V2 schema (careful!)
|
||||
return rollbackFromV2(tx)
|
||||
},
|
||||
},
|
||||
// Future migrations go here
|
||||
// {
|
||||
// ID: "202601040001",
|
||||
// Migrate: func(tx *gorm.DB) error {
|
||||
// // Add new column, index, etc.
|
||||
// return tx.Exec("ALTER TABLE packages ADD COLUMN new_field VARCHAR(255)").Error
|
||||
// },
|
||||
// Rollback: func(tx *gorm.DB) error {
|
||||
// return tx.Exec("ALTER TABLE packages DROP COLUMN new_field").Error
|
||||
// },
|
||||
// },
|
||||
}
|
||||
}
|
||||
|
||||
// migrateToV2 creates the complete V2 schema
|
||||
func migrateToV2(tx *gorm.DB) error {
|
||||
// Get dialect name for database-specific features
|
||||
dialectName := tx.Dialector.Name()
|
||||
|
||||
// Step 1: Create all tables using GORM AutoMigrate
|
||||
// This handles cross-database compatibility automatically
|
||||
if err := tx.AutoMigrate(GetAllModels()...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Step 2: Seed default registries
|
||||
registries := []RegistryModel{
|
||||
{Name: "npm", DisplayName: "NPM Registry", UpstreamURL: "https://registry.npmjs.org", Enabled: true, ScanByDefault: true},
|
||||
{Name: "pypi", DisplayName: "PyPI", UpstreamURL: "https://pypi.org", Enabled: true, ScanByDefault: true},
|
||||
{Name: "go", DisplayName: "Go Modules", UpstreamURL: "https://proxy.golang.org", Enabled: true, ScanByDefault: true},
|
||||
}
|
||||
|
||||
for _, reg := range registries {
|
||||
// Upsert: create if not exists
|
||||
if err := tx.Where("name = ?", reg.Name).FirstOrCreate(®).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Create database-specific optimizations
|
||||
if dialectName == "postgres" {
|
||||
if err := createPostgreSQLOptimizations(tx); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if dialectName == "mysql" {
|
||||
if err := createMySQLOptimizations(tx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createPostgreSQLOptimizations adds PostgreSQL-specific features
|
||||
func createPostgreSQLOptimizations(tx *gorm.DB) error {
|
||||
optimizations := []string{
|
||||
// Create GIN indexes for JSONB columns
|
||||
`CREATE INDEX IF NOT EXISTS idx_package_metadata_keywords_gin
|
||||
ON package_metadata USING GIN(keywords)`,
|
||||
|
||||
`CREATE INDEX IF NOT EXISTS idx_package_metadata_raw_gin
|
||||
ON package_metadata USING GIN(raw_metadata)`,
|
||||
|
||||
`CREATE INDEX IF NOT EXISTS idx_vulnerabilities_references_gin
|
||||
ON vulnerabilities USING GIN(references)`,
|
||||
|
||||
// Create partial indexes (only non-deleted records)
|
||||
`CREATE INDEX IF NOT EXISTS idx_packages_active
|
||||
ON packages(registry_id, name, version) WHERE deleted_at IS NULL`,
|
||||
|
||||
`CREATE INDEX IF NOT EXISTS idx_packages_vulnerable
|
||||
ON packages(vulnerability_count, highest_severity)
|
||||
WHERE vulnerability_count > 0 AND deleted_at IS NULL`,
|
||||
|
||||
// Create view for vulnerable packages
|
||||
`CREATE OR REPLACE VIEW v_vulnerable_packages AS
|
||||
SELECT
|
||||
r.name AS registry,
|
||||
p.name,
|
||||
p.version,
|
||||
p.vulnerability_count,
|
||||
p.highest_severity,
|
||||
p.last_scanned_at
|
||||
FROM packages p
|
||||
JOIN registries r ON p.registry_id = r.id
|
||||
WHERE p.vulnerability_count > 0 AND p.deleted_at IS NULL
|
||||
ORDER BY
|
||||
CASE p.highest_severity
|
||||
WHEN 'critical' THEN 1
|
||||
WHEN 'high' THEN 2
|
||||
WHEN 'medium' THEN 3
|
||||
WHEN 'low' THEN 4
|
||||
ELSE 5
|
||||
END,
|
||||
p.vulnerability_count DESC`,
|
||||
|
||||
// Create function for automatic partition creation
|
||||
`CREATE OR REPLACE FUNCTION create_next_month_partitions()
|
||||
RETURNS void AS $$
|
||||
DECLARE
|
||||
next_month DATE := date_trunc('month', NOW() + INTERVAL '2 months');
|
||||
partition_name TEXT;
|
||||
start_date TEXT;
|
||||
end_date TEXT;
|
||||
BEGIN
|
||||
-- Download events partition
|
||||
partition_name := 'download_events_' || to_char(next_month, 'YYYY_MM');
|
||||
start_date := to_char(next_month, 'YYYY-MM-DD');
|
||||
end_date := to_char(next_month + INTERVAL '1 month', 'YYYY-MM-DD');
|
||||
|
||||
EXECUTE format('CREATE TABLE IF NOT EXISTS %I PARTITION OF download_events FOR VALUES FROM (%L) TO (%L)',
|
||||
partition_name, start_date, end_date);
|
||||
|
||||
-- Audit log partition
|
||||
partition_name := 'audit_log_' || to_char(next_month, 'YYYY_MM');
|
||||
EXECUTE format('CREATE TABLE IF NOT EXISTS %I PARTITION OF audit_log FOR VALUES FROM (%L) TO (%L)',
|
||||
partition_name, start_date, end_date);
|
||||
|
||||
RAISE NOTICE 'Created partitions for %', to_char(next_month, 'YYYY-MM');
|
||||
END;
|
||||
$$ LANGUAGE plpgsql`,
|
||||
}
|
||||
|
||||
for _, sql := range optimizations {
|
||||
if err := tx.Exec(sql).Error; err != nil {
|
||||
// Log warning but don't fail migration
|
||||
// Some optimizations might already exist
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createMySQLOptimizations adds MySQL-specific features
|
||||
func createMySQLOptimizations(tx *gorm.DB) error {
|
||||
optimizations := []string{
|
||||
// Create view for vulnerable packages
|
||||
`CREATE OR REPLACE VIEW v_vulnerable_packages AS
|
||||
SELECT
|
||||
r.name AS registry,
|
||||
p.name,
|
||||
p.version,
|
||||
p.vulnerability_count,
|
||||
p.highest_severity,
|
||||
p.last_scanned_at
|
||||
FROM packages p
|
||||
JOIN registries r ON p.registry_id = r.id
|
||||
WHERE p.vulnerability_count > 0 AND p.deleted_at IS NULL
|
||||
ORDER BY
|
||||
CASE p.highest_severity
|
||||
WHEN 'critical' THEN 1
|
||||
WHEN 'high' THEN 2
|
||||
WHEN 'medium' THEN 3
|
||||
WHEN 'low' THEN 4
|
||||
ELSE 5
|
||||
END,
|
||||
p.vulnerability_count DESC`,
|
||||
}
|
||||
|
||||
for _, sql := range optimizations {
|
||||
if err := tx.Exec(sql).Error; err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// rollbackFromV2 drops all V2 tables (USE WITH CAUTION!)
|
||||
func rollbackFromV2(tx *gorm.DB) error {
|
||||
// Drop in reverse order to respect foreign keys
|
||||
tables := []string{
|
||||
"audit_log",
|
||||
"download_stats_daily",
|
||||
"download_stats_hourly",
|
||||
"download_events",
|
||||
"cve_bypasses",
|
||||
"scan_results",
|
||||
"package_vulnerabilities",
|
||||
"vulnerabilities",
|
||||
"package_metadata",
|
||||
"packages",
|
||||
"registries",
|
||||
}
|
||||
|
||||
// Drop PostgreSQL-specific objects
|
||||
if tx.Dialector.Name() == "postgres" {
|
||||
tx.Exec("DROP VIEW IF EXISTS v_vulnerable_packages")
|
||||
tx.Exec("DROP FUNCTION IF EXISTS create_next_month_partitions()")
|
||||
}
|
||||
|
||||
// Drop MySQL-specific objects
|
||||
if tx.Dialector.Name() == "mysql" {
|
||||
tx.Exec("DROP VIEW IF EXISTS v_vulnerable_packages")
|
||||
}
|
||||
|
||||
// Drop all tables
|
||||
for _, table := range tables {
|
||||
if err := tx.Migrator().DropTable(table); err != nil {
|
||||
// Continue even if table doesn't exist
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,328 @@
|
||||
package gormstore
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// BaseModel provides common fields for all models with audit trail
|
||||
type BaseModel struct {
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
UpdatedAt time.Time `gorm:"not null"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"` // Soft delete support (auto-generated index name per table)
|
||||
}
|
||||
|
||||
// RegistryModel represents package registries (normalized)
|
||||
// This eliminates repetition of "npm", "pypi", "go" across millions of rows
|
||||
type RegistryModel struct {
|
||||
ID int32 `gorm:"primaryKey;autoIncrement"`
|
||||
Name string `gorm:"uniqueIndex:idx_registry_name;not null;size:50"` // npm, pypi, go
|
||||
DisplayName string `gorm:"not null;size:100"` // NPM Registry, PyPI, Go Modules
|
||||
UpstreamURL string `gorm:"not null;size:512"` // https://registry.npmjs.org
|
||||
Enabled bool `gorm:"not null;default:true;index:idx_registry_enabled"`
|
||||
ScanByDefault bool `gorm:"not null;default:true"`
|
||||
BaseModel
|
||||
}
|
||||
|
||||
func (RegistryModel) TableName() string {
|
||||
return "registries"
|
||||
}
|
||||
|
||||
// PackageModel represents the core package data (optimized)
|
||||
type PackageModel struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
RegistryID int32 `gorm:"not null;index:idx_package_registry_name_version,priority:1"` // Foreign key to registries
|
||||
Name string `gorm:"not null;size:255;index:idx_package_name;index:idx_package_registry_name_version,priority:2"`
|
||||
Version string `gorm:"not null;size:100;index:idx_package_registry_name_version,priority:3"`
|
||||
|
||||
// Storage information
|
||||
StorageKey string `gorm:"not null;uniqueIndex:idx_package_storage_key;size:512"`
|
||||
Size int64 `gorm:"not null;index:idx_package_size"` // For storage quota queries
|
||||
ChecksumMD5 string `gorm:"size:32;index:idx_package_md5"`
|
||||
ChecksumSHA256 string `gorm:"size:64;index:idx_package_sha256"`
|
||||
UpstreamURL string `gorm:"size:1024"`
|
||||
|
||||
// Cache management
|
||||
CachedAt time.Time `gorm:"not null;index:idx_package_cached_at"`
|
||||
LastAccessed time.Time `gorm:"not null;index:idx_package_last_accessed"` // For LRU eviction
|
||||
ExpiresAt *time.Time `gorm:"index:idx_package_expires_at"` // For cache invalidation
|
||||
AccessCount int64 `gorm:"not null;default:0;index:idx_package_access_count"` // Total access count (denormalized for performance)
|
||||
|
||||
// Security
|
||||
SecurityScanned bool `gorm:"not null;default:false;index:idx_package_security_scanned"`
|
||||
LastScannedAt *time.Time `gorm:"index:idx_package_last_scanned"`
|
||||
VulnerabilityCount int `gorm:"not null;default:0;index:idx_package_vuln_count"` // Denormalized for fast filtering
|
||||
HighestSeverity string `gorm:"size:20;index:idx_package_severity"` // critical, high, medium, low, none
|
||||
CriticalCount int `gorm:"not null;default:0"` // Count of critical vulnerabilities
|
||||
HighCount int `gorm:"not null;default:0"` // Count of high vulnerabilities
|
||||
ModerateCount int `gorm:"not null;default:0"` // Count of moderate vulnerabilities
|
||||
LowCount int `gorm:"not null;default:0"` // Count of low vulnerabilities
|
||||
|
||||
// Authentication
|
||||
RequiresAuth bool `gorm:"not null;default:false;index:idx_package_requires_auth"`
|
||||
AuthProvider string `gorm:"size:50;index:idx_package_auth_provider"` // github, gitlab, custom
|
||||
|
||||
BaseModel
|
||||
|
||||
// Relationships
|
||||
Registry RegistryModel `gorm:"foreignKey:RegistryID;constraint:OnUpdate:CASCADE,OnDelete:RESTRICT"`
|
||||
Metadata *PackageMetadataModel `gorm:"foreignKey:PackageID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"`
|
||||
ScanResults []ScanResultModel `gorm:"foreignKey:PackageID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"`
|
||||
Vulnerabilities []PackageVulnerabilityModel `gorm:"foreignKey:PackageID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"`
|
||||
}
|
||||
|
||||
func (PackageModel) TableName() string {
|
||||
return "packages"
|
||||
}
|
||||
|
||||
// BeforeCreate hook to set access count
|
||||
func (p *PackageModel) BeforeCreate(tx *gorm.DB) error {
|
||||
if p.AccessCount == 0 {
|
||||
p.AccessCount = 0
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PackageMetadataModel stores structured package metadata (1:1 with packages)
|
||||
// Separated from main table to reduce row size and improve query performance
|
||||
type PackageMetadataModel struct {
|
||||
PackageID int64 `gorm:"primaryKey;not null"` // 1:1 relationship
|
||||
Author string `gorm:"size:255;index:idx_metadata_author"`
|
||||
License string `gorm:"size:100;index:idx_metadata_license"`
|
||||
Homepage string `gorm:"size:512"`
|
||||
Repository string `gorm:"size:512"`
|
||||
Description string `gorm:"type:text"`
|
||||
Keywords PostgresArray `gorm:"type:text"` // JSONB array for PostgreSQL, JSON for MySQL/SQLite
|
||||
RawMetadata JSONBField `gorm:"type:jsonb"` // Full metadata as JSONB (PostgreSQL) or JSON
|
||||
BaseModel
|
||||
}
|
||||
|
||||
func (PackageMetadataModel) TableName() string {
|
||||
return "package_metadata"
|
||||
}
|
||||
|
||||
// ScanResultModel represents security scan results (optimized)
|
||||
type ScanResultModel struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
PackageID int64 `gorm:"not null;index:idx_scan_package_scanner,priority:1"` // Foreign key
|
||||
Scanner string `gorm:"not null;size:50;index:idx_scan_scanner;index:idx_scan_package_scanner,priority:2"`
|
||||
ScannedAt time.Time `gorm:"not null;index:idx_scan_scanned_at"`
|
||||
Status string `gorm:"not null;size:20;index:idx_scan_status"` // success, failed, pending
|
||||
VulnCount int `gorm:"not null;default:0;index:idx_scan_vuln_count"`
|
||||
CriticalCount int `gorm:"not null;default:0"`
|
||||
HighCount int `gorm:"not null;default:0"`
|
||||
MediumCount int `gorm:"not null;default:0"`
|
||||
LowCount int `gorm:"not null;default:0"`
|
||||
ScanDuration int `gorm:"not null;default:0"` // milliseconds
|
||||
Details JSONBField `gorm:"type:jsonb"` // Scanner-specific details
|
||||
BaseModel
|
||||
|
||||
Package PackageModel `gorm:"foreignKey:PackageID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"`
|
||||
}
|
||||
|
||||
func (ScanResultModel) TableName() string {
|
||||
return "scan_results"
|
||||
}
|
||||
|
||||
// VulnerabilityModel represents unique vulnerabilities (normalized)
|
||||
type VulnerabilityModel struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
CVEID string `gorm:"uniqueIndex:idx_vuln_cve_id;not null;size:50"` // CVE-2021-12345
|
||||
Title string `gorm:"not null;size:512"`
|
||||
Description string `gorm:"type:text"`
|
||||
Severity string `gorm:"not null;size:20;index:idx_vuln_severity"` // critical, high, medium, low
|
||||
CVSS float32 `gorm:"index:idx_vuln_cvss"` // CVSS score for sorting
|
||||
PublishedAt time.Time `gorm:"not null;index:idx_vuln_published"`
|
||||
FixedVersion string `gorm:"size:100"` // First version where it's fixed
|
||||
References PostgresArray `gorm:"type:text"` // URLs to advisories
|
||||
BaseModel
|
||||
}
|
||||
|
||||
func (VulnerabilityModel) TableName() string {
|
||||
return "vulnerabilities"
|
||||
}
|
||||
|
||||
// PackageVulnerabilityModel is a many-to-many relationship between packages and vulnerabilities
|
||||
type PackageVulnerabilityModel struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
PackageID int64 `gorm:"not null;index:idx_pkg_vuln_package,priority:1;index:idx_pkg_vuln_composite,priority:1"`
|
||||
VulnerabilityID int64 `gorm:"not null;index:idx_pkg_vuln_vuln,priority:1;index:idx_pkg_vuln_composite,priority:2"`
|
||||
Scanner string `gorm:"not null;size:50;index:idx_pkg_vuln_scanner"`
|
||||
DetectedAt time.Time `gorm:"not null;index:idx_pkg_vuln_detected"`
|
||||
Bypassed bool `gorm:"not null;default:false;index:idx_pkg_vuln_bypassed"`
|
||||
BypassID *int64 `gorm:"index:idx_pkg_vuln_bypass_id"` // Reference to bypass if applicable
|
||||
BaseModel
|
||||
|
||||
Package PackageModel `gorm:"foreignKey:PackageID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"`
|
||||
Vulnerability VulnerabilityModel `gorm:"foreignKey:VulnerabilityID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"`
|
||||
}
|
||||
|
||||
func (PackageVulnerabilityModel) TableName() string {
|
||||
return "package_vulnerabilities"
|
||||
}
|
||||
|
||||
// CVEBypassModel represents CVE bypass rules (improved)
|
||||
type CVEBypassModel struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
Type string `gorm:"not null;size:20;index:idx_bypass_type"` // cve, package, registry
|
||||
Target string `gorm:"not null;size:512;index:idx_bypass_target"` // CVE-ID, package name, etc.
|
||||
Reason string `gorm:"not null;type:text"`
|
||||
CreatedBy string `gorm:"not null;size:255;index:idx_bypass_created_by"`
|
||||
ExpiresAt time.Time `gorm:"not null;index:idx_bypass_expires_at"`
|
||||
NotifyOnExpiry bool `gorm:"not null;default:false"`
|
||||
Active bool `gorm:"not null;default:true;index:idx_bypass_active"`
|
||||
UsageCount int64 `gorm:"not null;default:0"` // How many times this bypass has been used
|
||||
LastUsedAt *time.Time `gorm:"index:idx_bypass_last_used"`
|
||||
|
||||
// Scope limiting (optional)
|
||||
RegistryID *int32 `gorm:"index:idx_bypass_registry"` // NULL = all registries
|
||||
PackageID *int64 `gorm:"index:idx_bypass_package"` // NULL = all packages
|
||||
|
||||
BaseModel
|
||||
}
|
||||
|
||||
func (CVEBypassModel) TableName() string {
|
||||
return "cve_bypasses"
|
||||
}
|
||||
|
||||
// DownloadEventModel represents raw download events (partitioned by month)
|
||||
// This table should use PostgreSQL partitioning or time-series DB features
|
||||
type DownloadEventModel struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
PackageID int64 `gorm:"not null;index:idx_download_package,priority:1"`
|
||||
RegistryID int32 `gorm:"not null;index:idx_download_registry"`
|
||||
DownloadedAt time.Time `gorm:"not null;index:idx_download_time;index:idx_download_package,priority:2"` // Partition key
|
||||
UserAgent string `gorm:"size:512"` // For analytics
|
||||
IPAddress string `gorm:"size:45;index:idx_download_ip"` // IPv6 support
|
||||
Authenticated bool `gorm:"not null;default:false"`
|
||||
Username string `gorm:"size:255;index:idx_download_user"`
|
||||
|
||||
// No BaseModel - this is append-only, no updates/deletes on individual rows
|
||||
// Partitioned tables handle cleanup via DROP PARTITION
|
||||
}
|
||||
|
||||
func (DownloadEventModel) TableName() string {
|
||||
return "download_events"
|
||||
}
|
||||
|
||||
// DownloadStatsHourlyModel represents pre-aggregated hourly statistics (partitioned)
|
||||
type DownloadStatsHourlyModel struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
RegistryID int32 `gorm:"not null;index:idx_stats_hourly_composite,priority:1"`
|
||||
PackageID *int64 `gorm:"index:idx_stats_hourly_package"` // NULL = all packages in registry
|
||||
TimeBucket time.Time `gorm:"not null;index:idx_stats_hourly_composite,priority:2"` // Truncated to hour
|
||||
DownloadCount int64 `gorm:"not null;default:0"`
|
||||
UniqueIPs int64 `gorm:"not null;default:0"` // Unique downloaders
|
||||
AuthDownloads int64 `gorm:"not null;default:0"` // Authenticated downloads
|
||||
|
||||
BaseModel
|
||||
}
|
||||
|
||||
func (DownloadStatsHourlyModel) TableName() string {
|
||||
return "download_stats_hourly"
|
||||
}
|
||||
|
||||
// DownloadStatsDailyModel represents pre-aggregated daily statistics
|
||||
type DownloadStatsDailyModel struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
RegistryID int32 `gorm:"not null;index:idx_stats_daily_composite,priority:1"`
|
||||
PackageID *int64 `gorm:"index:idx_stats_daily_package"` // NULL = all packages in registry
|
||||
TimeBucket time.Time `gorm:"not null;index:idx_stats_daily_composite,priority:2"` // Truncated to day
|
||||
DownloadCount int64 `gorm:"not null;default:0"`
|
||||
UniqueIPs int64 `gorm:"not null;default:0"`
|
||||
AuthDownloads int64 `gorm:"not null;default:0"`
|
||||
TopUserAgents JSONBField `gorm:"type:jsonb"` // Top 10 user agents
|
||||
|
||||
BaseModel
|
||||
}
|
||||
|
||||
func (DownloadStatsDailyModel) TableName() string {
|
||||
return "download_stats_daily"
|
||||
}
|
||||
|
||||
// AuditLogModel tracks all important changes (optional, for compliance)
|
||||
type AuditLogModel struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement"`
|
||||
EntityType string `gorm:"not null;size:50;index:idx_audit_entity_type"` // package, bypass, registry
|
||||
EntityID int64 `gorm:"not null;index:idx_audit_entity_id"`
|
||||
Action string `gorm:"not null;size:20;index:idx_audit_action"` // create, update, delete
|
||||
Username string `gorm:"not null;size:255;index:idx_audit_username"`
|
||||
Timestamp time.Time `gorm:"not null;index:idx_audit_timestamp"`
|
||||
Changes JSONBField `gorm:"type:jsonb"` // Before/after values
|
||||
IPAddress string `gorm:"size:45"`
|
||||
UserAgent string `gorm:"size:512"`
|
||||
|
||||
// No BaseModel - append-only audit log
|
||||
}
|
||||
|
||||
func (AuditLogModel) TableName() string {
|
||||
return "audit_log"
|
||||
}
|
||||
|
||||
// JSONBField is a custom type for JSONB (PostgreSQL) / JSON (MySQL/SQLite)
|
||||
type JSONBField map[string]interface{}
|
||||
|
||||
func (j JSONBField) Value() (driver.Value, error) {
|
||||
if j == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return json.Marshal(j)
|
||||
}
|
||||
|
||||
func (j *JSONBField) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
*j = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
bytes, ok := value.([]byte)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return json.Unmarshal(bytes, j)
|
||||
}
|
||||
|
||||
// PostgresArray is a custom type for PostgreSQL arrays stored as JSON
|
||||
type PostgresArray []string
|
||||
|
||||
func (a PostgresArray) Value() (driver.Value, error) {
|
||||
if a == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return json.Marshal(a)
|
||||
}
|
||||
|
||||
func (a *PostgresArray) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
*a = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
bytes, ok := value.([]byte)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return json.Unmarshal(bytes, a)
|
||||
}
|
||||
|
||||
// GetAllModels returns all models for GORM auto-migration
|
||||
func GetAllModels() []interface{} {
|
||||
return []interface{}{
|
||||
&RegistryModel{},
|
||||
&PackageModel{},
|
||||
&PackageMetadataModel{},
|
||||
&ScanResultModel{},
|
||||
&VulnerabilityModel{},
|
||||
&PackageVulnerabilityModel{},
|
||||
&CVEBypassModel{},
|
||||
&DownloadEventModel{},
|
||||
&DownloadStatsHourlyModel{},
|
||||
&DownloadStatsDailyModel{},
|
||||
&AuditLogModel{},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,380 @@
|
||||
package gormstore
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// PartitionManager handles automatic partition creation and cleanup for PostgreSQL
|
||||
type PartitionManager struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewPartitionManager creates a new partition manager
|
||||
func NewPartitionManager(db *gorm.DB) *PartitionManager {
|
||||
return &PartitionManager{db: db}
|
||||
}
|
||||
|
||||
// EnsurePartitions ensures required partitions exist for current and future months
|
||||
func (pm *PartitionManager) EnsurePartitions() error {
|
||||
// Check if we're using PostgreSQL
|
||||
if pm.db.Dialector.Name() != "postgres" {
|
||||
log.Debug().Msg("Partitioning only supported on PostgreSQL, skipping")
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Info().Msg("Ensuring partitions exist")
|
||||
|
||||
// Create partitions for download_events
|
||||
if err := pm.ensureDownloadEventPartitions(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create partitions for audit_log
|
||||
if err := pm.ensureAuditLogPartitions(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set up automatic partition creation
|
||||
if err := pm.createPartitionFunction(); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to create partition function (may already exist)")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureDownloadEventPartitions creates download_events partitions
|
||||
func (pm *PartitionManager) ensureDownloadEventPartitions() error {
|
||||
// Check if table is already partitioned
|
||||
var isPartitioned bool
|
||||
err := pm.db.Raw(`
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM pg_partitioned_table
|
||||
WHERE partrelid = 'download_events'::regclass
|
||||
)
|
||||
`).Scan(&isPartitioned).Error
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !isPartitioned {
|
||||
log.Info().Msg("Converting download_events to partitioned table")
|
||||
|
||||
// Rename existing table
|
||||
if err := pm.db.Exec("ALTER TABLE IF EXISTS download_events RENAME TO download_events_old").Error; err != nil {
|
||||
log.Warn().Err(err).Msg("Could not rename old table (may not exist)")
|
||||
}
|
||||
|
||||
// Create partitioned table
|
||||
createTableSQL := `
|
||||
CREATE TABLE IF NOT EXISTS download_events (
|
||||
id BIGSERIAL,
|
||||
package_id BIGINT NOT NULL,
|
||||
registry_id INTEGER NOT NULL,
|
||||
downloaded_at TIMESTAMP NOT NULL,
|
||||
user_agent VARCHAR(512),
|
||||
ip_address VARCHAR(45),
|
||||
authenticated BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
username VARCHAR(255)
|
||||
) PARTITION BY RANGE (downloaded_at)
|
||||
`
|
||||
|
||||
if err := pm.db.Exec(createTableSQL).Error; err != nil {
|
||||
return fmt.Errorf("failed to create partitioned table: %w", err)
|
||||
}
|
||||
|
||||
log.Info().Msg("Created partitioned download_events table")
|
||||
}
|
||||
|
||||
// Create partitions for past 3 months, current month, and next 3 months
|
||||
now := time.Now()
|
||||
for i := -3; i <= 3; i++ {
|
||||
month := now.AddDate(0, i, 0)
|
||||
if err := pm.createDownloadEventPartition(month); err != nil {
|
||||
log.Error().Err(err).Time("month", month).Msg("Failed to create partition")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createDownloadEventPartition creates a partition for a specific month
|
||||
func (pm *PartitionManager) createDownloadEventPartition(month time.Time) error {
|
||||
// Truncate to start of month
|
||||
startOfMonth := time.Date(month.Year(), month.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
endOfMonth := startOfMonth.AddDate(0, 1, 0)
|
||||
|
||||
partitionName := fmt.Sprintf("download_events_%d_%02d", month.Year(), month.Month())
|
||||
|
||||
// Check if partition already exists
|
||||
var exists bool
|
||||
err := pm.db.Raw("SELECT EXISTS (SELECT 1 FROM pg_tables WHERE tablename = ?)", partitionName).Scan(&exists).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if exists {
|
||||
log.Debug().Str("partition", partitionName).Msg("Partition already exists")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create partition
|
||||
createPartitionSQL := fmt.Sprintf(`
|
||||
CREATE TABLE %s PARTITION OF download_events
|
||||
FOR VALUES FROM ('%s') TO ('%s')
|
||||
`, partitionName, startOfMonth.Format("2006-01-02"), endOfMonth.Format("2006-01-02"))
|
||||
|
||||
if err := pm.db.Exec(createPartitionSQL).Error; err != nil {
|
||||
return fmt.Errorf("failed to create partition %s: %w", partitionName, err)
|
||||
}
|
||||
|
||||
// Create indexes on partition
|
||||
indexSQL := []string{
|
||||
fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s_package_idx ON %s(package_id, downloaded_at)", partitionName, partitionName),
|
||||
fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s_registry_idx ON %s(registry_id)", partitionName, partitionName),
|
||||
fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s_time_idx ON %s(downloaded_at)", partitionName, partitionName),
|
||||
}
|
||||
|
||||
for _, sql := range indexSQL {
|
||||
if err := pm.db.Exec(sql).Error; err != nil {
|
||||
log.Warn().Err(err).Str("sql", sql).Msg("Failed to create index")
|
||||
}
|
||||
}
|
||||
|
||||
log.Info().Str("partition", partitionName).Msg("Created partition")
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureAuditLogPartitions creates audit_log partitions
|
||||
func (pm *PartitionManager) ensureAuditLogPartitions() error {
|
||||
// Check if table exists
|
||||
var exists bool
|
||||
err := pm.db.Raw("SELECT EXISTS (SELECT 1 FROM pg_tables WHERE tablename = 'audit_log')").Scan(&exists).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !exists {
|
||||
// Create partitioned table
|
||||
createTableSQL := `
|
||||
CREATE TABLE IF NOT EXISTS audit_log (
|
||||
id BIGSERIAL,
|
||||
entity_type VARCHAR(50) NOT NULL,
|
||||
entity_id BIGINT NOT NULL,
|
||||
action VARCHAR(20) NOT NULL,
|
||||
username VARCHAR(255) NOT NULL,
|
||||
timestamp TIMESTAMP NOT NULL,
|
||||
changes JSONB,
|
||||
ip_address VARCHAR(45),
|
||||
user_agent VARCHAR(512)
|
||||
) PARTITION BY RANGE (timestamp)
|
||||
`
|
||||
|
||||
if err := pm.db.Exec(createTableSQL).Error; err != nil {
|
||||
return fmt.Errorf("failed to create audit_log table: %w", err)
|
||||
}
|
||||
|
||||
log.Info().Msg("Created partitioned audit_log table")
|
||||
}
|
||||
|
||||
// Create partitions for past month, current month, and next 2 months
|
||||
now := time.Now()
|
||||
for i := -1; i <= 2; i++ {
|
||||
month := now.AddDate(0, i, 0)
|
||||
if err := pm.createAuditLogPartition(month); err != nil {
|
||||
log.Error().Err(err).Time("month", month).Msg("Failed to create audit partition")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createAuditLogPartition creates a partition for a specific month
|
||||
func (pm *PartitionManager) createAuditLogPartition(month time.Time) error {
|
||||
startOfMonth := time.Date(month.Year(), month.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
endOfMonth := startOfMonth.AddDate(0, 1, 0)
|
||||
|
||||
partitionName := fmt.Sprintf("audit_log_%d_%02d", month.Year(), month.Month())
|
||||
|
||||
// Check if partition already exists
|
||||
var exists bool
|
||||
err := pm.db.Raw("SELECT EXISTS (SELECT 1 FROM pg_tables WHERE tablename = ?)", partitionName).Scan(&exists).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create partition
|
||||
createPartitionSQL := fmt.Sprintf(`
|
||||
CREATE TABLE %s PARTITION OF audit_log
|
||||
FOR VALUES FROM ('%s') TO ('%s')
|
||||
`, partitionName, startOfMonth.Format("2006-01-02"), endOfMonth.Format("2006-01-02"))
|
||||
|
||||
if err := pm.db.Exec(createPartitionSQL).Error; err != nil {
|
||||
return fmt.Errorf("failed to create partition %s: %w", partitionName, err)
|
||||
}
|
||||
|
||||
// Create indexes
|
||||
indexSQL := []string{
|
||||
fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s_entity_idx ON %s(entity_type, entity_id)", partitionName, partitionName),
|
||||
fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s_user_idx ON %s(username)", partitionName, partitionName),
|
||||
fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s_time_idx ON %s(timestamp)", partitionName, partitionName),
|
||||
}
|
||||
|
||||
for _, sql := range indexSQL {
|
||||
if err := pm.db.Exec(sql).Error; err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to create audit index")
|
||||
}
|
||||
}
|
||||
|
||||
log.Info().Str("partition", partitionName).Msg("Created audit partition")
|
||||
return nil
|
||||
}
|
||||
|
||||
// createPartitionFunction creates a PostgreSQL function for automatic partition creation
|
||||
func (pm *PartitionManager) createPartitionFunction() error {
|
||||
functionSQL := `
|
||||
CREATE OR REPLACE FUNCTION create_next_month_partitions()
|
||||
RETURNS void AS $$
|
||||
DECLARE
|
||||
next_month DATE := date_trunc('month', NOW() + INTERVAL '2 months');
|
||||
partition_name TEXT;
|
||||
start_date TEXT;
|
||||
end_date TEXT;
|
||||
BEGIN
|
||||
-- Create download_events partition
|
||||
partition_name := 'download_events_' || to_char(next_month, 'YYYY_MM');
|
||||
start_date := to_char(next_month, 'YYYY-MM-DD');
|
||||
end_date := to_char(next_month + INTERVAL '1 month', 'YYYY-MM-DD');
|
||||
|
||||
EXECUTE format('CREATE TABLE IF NOT EXISTS %I PARTITION OF download_events FOR VALUES FROM (%L) TO (%L)',
|
||||
partition_name, start_date, end_date);
|
||||
|
||||
EXECUTE format('CREATE INDEX IF NOT EXISTS %I ON %I(package_id, downloaded_at)',
|
||||
partition_name || '_package_idx', partition_name);
|
||||
EXECUTE format('CREATE INDEX IF NOT EXISTS %I ON %I(registry_id)',
|
||||
partition_name || '_registry_idx', partition_name);
|
||||
|
||||
-- Create audit_log partition
|
||||
partition_name := 'audit_log_' || to_char(next_month, 'YYYY_MM');
|
||||
|
||||
EXECUTE format('CREATE TABLE IF NOT EXISTS %I PARTITION OF audit_log FOR VALUES FROM (%L) TO (%L)',
|
||||
partition_name, start_date, end_date);
|
||||
|
||||
EXECUTE format('CREATE INDEX IF NOT EXISTS %I ON %I(entity_type, entity_id)',
|
||||
partition_name || '_entity_idx', partition_name);
|
||||
|
||||
RAISE NOTICE 'Created partitions for %', to_char(next_month, 'YYYY-MM');
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
`
|
||||
|
||||
if err := pm.db.Exec(functionSQL).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info().Msg("Created partition management function")
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupOldPartitions drops partitions older than the retention period
|
||||
func (pm *PartitionManager) CleanupOldPartitions(retentionMonths int) error {
|
||||
if pm.db.Dialector.Name() != "postgres" {
|
||||
return nil
|
||||
}
|
||||
|
||||
cutoffDate := time.Now().AddDate(0, -retentionMonths, 0)
|
||||
cutoffPartition := fmt.Sprintf("%d_%02d", cutoffDate.Year(), cutoffDate.Month())
|
||||
|
||||
log.Info().
|
||||
Str("cutoff", cutoffPartition).
|
||||
Int("retention_months", retentionMonths).
|
||||
Msg("Cleaning up old partitions")
|
||||
|
||||
// Find and drop old download_events partitions
|
||||
var downloadPartitions []string
|
||||
err := pm.db.Raw(`
|
||||
SELECT tablename FROM pg_tables
|
||||
WHERE tablename LIKE 'download_events_%'
|
||||
AND tablename < 'download_events_' || ?
|
||||
`, cutoffPartition).Scan(&downloadPartitions).Error
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, partition := range downloadPartitions {
|
||||
log.Info().Str("partition", partition).Msg("Dropping old partition")
|
||||
if err := pm.db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", partition)).Error; err != nil {
|
||||
log.Error().Err(err).Str("partition", partition).Msg("Failed to drop partition")
|
||||
}
|
||||
}
|
||||
|
||||
// Find and drop old audit_log partitions
|
||||
var auditPartitions []string
|
||||
err = pm.db.Raw(`
|
||||
SELECT tablename FROM pg_tables
|
||||
WHERE tablename LIKE 'audit_log_%'
|
||||
AND tablename < 'audit_log_' || ?
|
||||
`, cutoffPartition).Scan(&auditPartitions).Error
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, partition := range auditPartitions {
|
||||
log.Info().Str("partition", partition).Msg("Dropping old audit partition")
|
||||
if err := pm.db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", partition)).Error; err != nil {
|
||||
log.Error().Err(err).Str("partition", partition).Msg("Failed to drop audit partition")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPartitionInfo returns information about current partitions
|
||||
func (pm *PartitionManager) GetPartitionInfo() (map[string]interface{}, error) {
|
||||
if pm.db.Dialector.Name() != "postgres" {
|
||||
return map[string]interface{}{"status": "not_applicable"}, nil
|
||||
}
|
||||
|
||||
info := make(map[string]interface{})
|
||||
|
||||
// Count download_events partitions
|
||||
var downloadCount int64
|
||||
pm.db.Raw("SELECT COUNT(*) FROM pg_tables WHERE tablename LIKE 'download_events_%'").Scan(&downloadCount)
|
||||
info["download_events_partitions"] = downloadCount
|
||||
|
||||
// Count audit_log partitions
|
||||
var auditCount int64
|
||||
pm.db.Raw("SELECT COUNT(*) FROM pg_tables WHERE tablename LIKE 'audit_log_%'").Scan(&auditCount)
|
||||
info["audit_log_partitions"] = auditCount
|
||||
|
||||
// Get partition sizes
|
||||
type PartitionSize struct {
|
||||
TableName string
|
||||
SizeMB float64
|
||||
}
|
||||
|
||||
var partitionSizes []PartitionSize
|
||||
pm.db.Raw(`
|
||||
SELECT
|
||||
tablename AS table_name,
|
||||
pg_total_relation_size(tablename::regclass) / 1024.0 / 1024.0 AS size_mb
|
||||
FROM pg_tables
|
||||
WHERE tablename LIKE 'download_events_%' OR tablename LIKE 'audit_log_%'
|
||||
ORDER BY size_mb DESC
|
||||
LIMIT 10
|
||||
`).Scan(&partitionSizes)
|
||||
|
||||
info["largest_partitions"] = partitionSizes
|
||||
|
||||
return info, nil
|
||||
}
|
||||
@@ -143,13 +143,18 @@ const (
|
||||
|
||||
// Stats represents metadata statistics
|
||||
type Stats struct {
|
||||
LastUpdated time.Time `json:"last_updated"`
|
||||
Registry string `json:"registry"`
|
||||
TotalPackages int64 `json:"total_packages"`
|
||||
TotalSize int64 `json:"total_size"`
|
||||
TotalDownloads int64 `json:"total_downloads"`
|
||||
ScannedPackages int64 `json:"scanned_packages"`
|
||||
VulnerablePackages int64 `json:"vulnerable_packages"`
|
||||
LastUpdated time.Time `json:"last_updated"`
|
||||
Registry string `json:"registry"`
|
||||
TotalPackages int64 `json:"total_packages"`
|
||||
TotalSize int64 `json:"total_size"`
|
||||
TotalDownloads int64 `json:"total_downloads"`
|
||||
ScannedPackages int64 `json:"scanned_packages"`
|
||||
VulnerablePackages int64 `json:"vulnerable_packages"`
|
||||
BlockedPackages int64 `json:"blocked_packages"`
|
||||
CriticalVulnerabilities int64 `json:"critical_vulnerabilities"`
|
||||
HighVulnerabilities int64 `json:"high_vulnerabilities"`
|
||||
ModerateVulnerabilities int64 `json:"moderate_vulnerabilities"`
|
||||
LowVulnerabilities int64 `json:"low_vulnerabilities"`
|
||||
}
|
||||
|
||||
// TimeSeriesDataPoint represents a single data point in time-series
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
+16
-3
@@ -148,6 +148,17 @@ func (h *Handler) handlePackagePage(ctx context.Context, w http.ResponseWriter,
|
||||
func (h *Handler) handlePackageFile(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
|
||||
packageName, version := extractPackageFileInfo(path)
|
||||
|
||||
// Make version unique by appending file type to avoid cache collisions
|
||||
// between .whl and .metadata files with same version
|
||||
cacheVersion := version
|
||||
if strings.HasSuffix(path, ".metadata") {
|
||||
cacheVersion = version + ".metadata"
|
||||
} else if strings.HasSuffix(path, ".whl") {
|
||||
cacheVersion = version + ".whl"
|
||||
} else if strings.HasSuffix(path, ".tar.gz") {
|
||||
cacheVersion = version + ".tar.gz"
|
||||
}
|
||||
|
||||
// Extract credentials from request
|
||||
credentials := h.credExtractor.Extract(r)
|
||||
credHash := h.credHasher.Hash(credentials)
|
||||
@@ -170,12 +181,13 @@ func (h *Handler) handlePackageFile(ctx context.Context, w http.ResponseWriter,
|
||||
Str("path", path).
|
||||
Str("package", packageName).
|
||||
Str("version", version).
|
||||
Str("cache_version", cacheVersion).
|
||||
Str("url", originalURL).
|
||||
Str("cred_hash", credHash).
|
||||
Bool("has_credentials", credentials != "").
|
||||
Msg("Handling PyPI package file request")
|
||||
|
||||
entry, err := h.cache.Get(ctx, "pypi", packageName, version, func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
entry, err := h.cache.Get(ctx, "pypi", packageName, cacheVersion, func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
// Prepare headers for upstream request
|
||||
headers := make(map[string]string)
|
||||
if credentials != "" {
|
||||
@@ -281,11 +293,12 @@ func isPackagePage(path string) bool {
|
||||
|
||||
// isPackageFile checks if the request is for a package file
|
||||
func isPackageFile(path string) bool {
|
||||
// Package files (not including .metadata files which need special handling)
|
||||
// Package files including .metadata files for PEP 658 support
|
||||
return strings.HasSuffix(path, ".whl") ||
|
||||
strings.HasSuffix(path, ".tar.gz") ||
|
||||
strings.HasSuffix(path, ".zip") ||
|
||||
strings.HasSuffix(path, ".egg")
|
||||
strings.HasSuffix(path, ".egg") ||
|
||||
strings.HasSuffix(path, ".metadata")
|
||||
}
|
||||
|
||||
// extractPackageName extracts package name from path
|
||||
|
||||
Reference in New Issue
Block a user