chore(schema): migrate to GORM V2 with multi-database support

- [x] Implement GORM V2 metadata store with SQLite, PostgreSQL, and MySQL support
- [x] Add database migration system using gormigrate for schema versioning
- [x] Create migration CLI tool with support for migrate, rollback, and status commands
- [x] Add Docker support for migration container (Dockerfile.migrate)
- [x] Implement automatic partition management for PostgreSQL time-series tables
- [x] Add background aggregation worker for download statistics
- [x] Support connection pooling configuration (max_open_conns, max_idle_conns, conn_max_lifetime)
- [x] Add blocking mechanism based on vulnerability thresholds in stats and handlers
- [x] Update Helm charts with migration init containers and multi-database configuration
- [x] Replace deprecated SQLite store with optimized GORM implementation
- [x] Add comprehensive integration tests for MySQL and PostgreSQL
- [x] Update frontend to display blocked packages and storage utilization
- [x] Add goreleaser configuration for migrate binary and container image
- [x] Update configuration examples with database backend options and recommendations
This commit is contained in:
2026-01-03 20:44:23 +00:00
parent b129279fb8
commit c0061b99e3
37 changed files with 5711 additions and 1222 deletions
+72 -7
View File
@@ -19,7 +19,7 @@ import (
"github.com/lukaszraczylo/gohoarder/pkg/health"
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
metafile "github.com/lukaszraczylo/gohoarder/pkg/metadata/file"
metasqlite "github.com/lukaszraczylo/gohoarder/pkg/metadata/sqlite"
metagorm "github.com/lukaszraczylo/gohoarder/pkg/metadata/gormstore"
"github.com/lukaszraczylo/gohoarder/pkg/metrics"
"github.com/lukaszraczylo/gohoarder/pkg/network"
"github.com/lukaszraczylo/gohoarder/pkg/prewarming"
@@ -119,18 +119,67 @@ func (a *App) initializeComponents() error {
log.Info().Str("backend", a.config.Metadata.Backend).Msg("Initializing metadata store")
switch a.config.Metadata.Backend {
case "sqlite":
a.metadata, err = metasqlite.New(metasqlite.Config{
Path: a.config.Metadata.Connection,
WALMode: a.config.Metadata.SQLite.WALMode,
// Use GORM for SQLite
a.metadata, err = metagorm.NewV2(metagorm.Config{
Driver: "sqlite",
DSN: metagorm.BuildSQLiteDSN(a.config.Metadata.SQLite.Path, a.config.Metadata.SQLite.WALMode),
MaxOpenConns: getOrDefault(a.config.Metadata.MaxOpenConns, 25),
MaxIdleConns: getOrDefault(a.config.Metadata.MaxIdleConns, 5),
ConnMaxLifetime: time.Duration(getOrDefault(a.config.Metadata.ConnMaxLifetime, 3600)) * time.Second,
LogLevel: getOrDefaultStr(a.config.Metadata.LogLevel, "warn"),
})
case "postgresql", "postgres":
// Use GORM for PostgreSQL
dsn := metagorm.BuildPostgresDSN(
a.config.Metadata.PostgreSQL.Host,
a.config.Metadata.PostgreSQL.Port,
a.config.Metadata.PostgreSQL.User,
a.config.Metadata.PostgreSQL.Password,
a.config.Metadata.PostgreSQL.Database,
getOrDefaultStr(a.config.Metadata.PostgreSQL.SSLMode, "disable"),
)
a.metadata, err = metagorm.NewV2(metagorm.Config{
Driver: "postgres",
DSN: dsn,
MaxOpenConns: getOrDefault(a.config.Metadata.MaxOpenConns, 25),
MaxIdleConns: getOrDefault(a.config.Metadata.MaxIdleConns, 5),
ConnMaxLifetime: time.Duration(getOrDefault(a.config.Metadata.ConnMaxLifetime, 3600)) * time.Second,
LogLevel: getOrDefaultStr(a.config.Metadata.LogLevel, "warn"),
})
case "mysql", "mariadb":
// Use GORM for MySQL/MariaDB
dsn := metagorm.BuildMySQLDSN(
a.config.Metadata.MySQL.Host,
a.config.Metadata.MySQL.Port,
a.config.Metadata.MySQL.User,
a.config.Metadata.MySQL.Password,
a.config.Metadata.MySQL.Database,
getOrDefaultStr(a.config.Metadata.MySQL.Charset, "utf8mb4"),
)
a.metadata, err = metagorm.NewV2(metagorm.Config{
Driver: "mysql",
DSN: dsn,
MaxOpenConns: getOrDefault(a.config.Metadata.MaxOpenConns, 25),
MaxIdleConns: getOrDefault(a.config.Metadata.MaxIdleConns, 5),
ConnMaxLifetime: time.Duration(getOrDefault(a.config.Metadata.ConnMaxLifetime, 3600)) * time.Second,
LogLevel: getOrDefaultStr(a.config.Metadata.LogLevel, "warn"),
})
case "file":
// Keep file backend as-is for file-based metadata
a.metadata, err = metafile.New(metafile.Config{
Path: a.config.Metadata.Connection,
})
default:
a.metadata, err = metasqlite.New(metasqlite.Config{
Path: "gohoarder.db",
WALMode: false, // Default to DELETE mode for compatibility
// Default to SQLite with GORM
log.Warn().Str("backend", a.config.Metadata.Backend).Msg("Unknown metadata backend, defaulting to SQLite with GORM")
a.metadata, err = metagorm.NewV2(metagorm.Config{
Driver: "sqlite",
DSN: metagorm.BuildSQLiteDSN("gohoarder.db", false),
LogLevel: "warn",
})
}
if err != nil {
@@ -479,3 +528,19 @@ func (a *App) startAggregationWorker(ctx context.Context) {
}
}
}
// getOrDefault picks value unless it is the zero int, in which case
// defaultValue is returned instead. Negative values count as "set" and are
// passed through unchanged.
func getOrDefault(value, defaultValue int) int {
	if value != 0 {
		return value
	}
	return defaultValue
}
// getOrDefaultStr picks value unless it is the empty string, in which case
// defaultValue is returned instead. Whitespace-only strings count as "set".
func getOrDefaultStr(value, defaultValue string) string {
	if len(value) > 0 {
		return value
	}
	return defaultValue
}
+27 -6
View File
@@ -142,6 +142,13 @@ func (a *App) handleListPackages(c *fiber.Ctx) error {
severityCounts[strings.ToUpper(vuln.Severity)]++
}
// Check if package should be blocked based on thresholds
isBlocked := false
if a.scanManager != nil {
blocked, _, _ := a.scanManager.CheckVulnerabilities(ctx, pkg.Registry, entry.originalName, pkg.Version)
isBlocked = blocked
}
pkgMap["vulnerabilities"] = map[string]interface{}{
"scanned": true,
"status": scanResult.Status,
@@ -152,18 +159,21 @@ func (a *App) handleListPackages(c *fiber.Ctx) error {
"moderate": severityCounts["MODERATE"],
"low": severityCounts["LOW"],
},
"total": scanResult.VulnerabilityCount,
"total": scanResult.VulnerabilityCount,
"isBlocked": isBlocked,
}
} else {
pkgMap["vulnerabilities"] = map[string]interface{}{
"scanned": false,
"status": "pending",
"scanned": false,
"status": "pending",
"isBlocked": false,
}
}
} else {
pkgMap["vulnerabilities"] = map[string]interface{}{
"scanned": false,
"status": "not_scanned",
"scanned": false,
"status": "not_scanned",
"isBlocked": false,
}
}
@@ -351,8 +361,9 @@ func (a *App) handleStats(c *fiber.Ctx) error {
packages = []*metadata.Package{}
}
// Calculate per-registry breakdown (exclude metadata entries like "list", "latest")
// Calculate per-registry breakdown and blocked packages count
registryStats := make(map[string]map[string]interface{})
blockedCount := int64(0)
for _, pkg := range packages {
// Skip metadata entries (npm metadata pages, pypi pages, etc.)
@@ -371,6 +382,14 @@ func (a *App) handleStats(c *fiber.Ctx) error {
registryStats[pkg.Registry]["count"] = registryStats[pkg.Registry]["count"].(int) + 1
registryStats[pkg.Registry]["size"] = registryStats[pkg.Registry]["size"].(int64) + pkg.Size
registryStats[pkg.Registry]["downloads"] = registryStats[pkg.Registry]["downloads"].(int64) + int64(pkg.DownloadCount)
// Check if package is blocked (only if security scanning is enabled and package is scanned)
if a.config.Security.Enabled && a.scanManager != nil && pkg.SecurityScanned {
blocked, _, _ := a.scanManager.CheckVulnerabilities(ctx, pkg.Registry, pkg.Name, pkg.Version)
if blocked {
blockedCount++
}
}
}
// Combine statistics using database stats for accuracy
@@ -378,12 +397,14 @@ func (a *App) handleStats(c *fiber.Ctx) error {
"total_packages": cacheStats.TotalPackages,
"total_downloads": cacheStats.TotalDownloads,
"total_size": cacheStats.TotalSize,
"max_cache_size": a.config.Cache.MaxSizeBytes,
"cache_hits": cacheStats.TotalDownloads,
"cache_misses": 0, // TODO: Track cache misses
"cache_evictions": 0, // TODO: Track evictions
"cache_size": cacheStats.TotalSize,
"scanned_packages": cacheStats.ScannedPackages,
"vulnerable_packages": cacheStats.VulnerablePackages,
"blocked_packages": blockedCount,
}
// Convert registry stats to interface map
+14 -7
View File
@@ -203,9 +203,12 @@ func (m *Manager) getOrFetch(ctx context.Context, registry, name, version string
return nil, err
}
// Skip security scan wait for metadata entries (index pages, lists, etc.)
isMetadataEntry := version == "list" || version == "page" || version == "latest" || version == "metadata"
// Wait briefly for initial scan to complete if scanner is enabled
// This prevents serving vulnerable packages on first request
if m.scanner != nil {
if m.scanner != nil && !isMetadataEntry {
// Wait up to 30 seconds for scan to complete
scanCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
@@ -360,15 +363,19 @@ func (m *Manager) store(ctx context.Context, registry, name, version string, dat
Metadata: make(map[string]string),
}
// Save metadata
if err := m.metadata.SavePackage(ctx, pkg); err != nil {
// Clean up storage if metadata save fails
_ = m.storage.Delete(ctx, storageKey) // #nosec G104 -- Cleanup, error logged
return nil, err
// Save metadata (skip metadata entries like index pages, lists, etc.)
isMetadataEntry := version == "list" || version == "page" || version == "latest" || version == "metadata"
if !isMetadataEntry {
if err := m.metadata.SavePackage(ctx, pkg); err != nil {
// Clean up storage if metadata save fails
_ = m.storage.Delete(ctx, storageKey) // #nosec G104 -- Cleanup, error logged
return nil, err
}
}
// Scan package if scanner is enabled (run in background to not block cache operations)
if m.scanner != nil {
// Skip scanning metadata entries (index pages, lists, etc.)
if m.scanner != nil && !isMetadataEntry {
go func() {
scanCtx := context.Background()
var filePath string
+29 -4
View File
@@ -73,10 +73,17 @@ type SMBConfig struct {
// MetadataConfig contains metadata store configuration
type MetadataConfig struct {
PostgreSQL PostgreSQLConfig `mapstructure:"postgresql" json:"postgresql"`
Backend string `mapstructure:"backend" json:"backend"`
Connection string `mapstructure:"connection" json:"connection"`
SQLite SQLiteConfig `mapstructure:"sqlite" json:"sqlite"`
PostgreSQL PostgreSQLConfig `mapstructure:"postgresql" json:"postgresql"`
MySQL MySQLConfig `mapstructure:"mysql" json:"mysql"`
// GORM-specific settings
MaxOpenConns int `mapstructure:"max_open_conns" json:"max_open_conns"`
MaxIdleConns int `mapstructure:"max_idle_conns" json:"max_idle_conns"`
ConnMaxLifetime int `mapstructure:"conn_max_lifetime" json:"conn_max_lifetime"` // seconds
LogLevel string `mapstructure:"log_level" json:"log_level"` // "silent", "error", "warn", "info"
}
// SQLiteConfig contains SQLite-specific configuration
@@ -88,11 +95,22 @@ type SQLiteConfig struct {
// PostgreSQLConfig contains PostgreSQL-specific configuration
type PostgreSQLConfig struct {
Host string `mapstructure:"host" json:"host"`
Port int `mapstructure:"port" json:"port"`
Database string `mapstructure:"database" json:"database"`
User string `mapstructure:"user" json:"user"`
Password string `mapstructure:"password" json:"-"`
SSLMode string `mapstructure:"ssl_mode" json:"ssl_mode"`
Port int `mapstructure:"port" json:"port"`
}
// MySQLConfig contains MySQL/MariaDB-specific configuration.
// Each field is loaded from the key named in its mapstructure tag.
type MySQLConfig struct {
	Host     string `mapstructure:"host" json:"host"`         // server host name or address
	Port     int    `mapstructure:"port" json:"port"`         // server TCP port
	Database string `mapstructure:"database" json:"database"` // schema to connect to
	User     string `mapstructure:"user" json:"user"`         // login user
	Password string `mapstructure:"password" json:"-"`        // Don't serialize: the "-" tag keeps it out of JSON output
	Charset  string `mapstructure:"charset" json:"charset"`   // connection charset; callers fall back to "utf8mb4" when empty
	// NOTE(review): BuildMySQLDSN always emits parseTime=True and takes no
	// parse-time argument, so this flag does not appear to be consulted when
	// the DSN is built — confirm whether it is read anywhere else.
	ParseTime bool `mapstructure:"parse_time" json:"parse_time"`
}
// CacheConfig contains cache management configuration
@@ -415,9 +433,16 @@ func (c *Config) Validate() error {
}
// Validate metadata backend
validMetadataBackends := map[string]bool{"sqlite": true, "postgresql": true, "file": true}
validMetadataBackends := map[string]bool{
"sqlite": true,
"postgresql": true,
"postgres": true,
"mysql": true,
"mariadb": true,
"file": true,
}
if !validMetadataBackends[c.Metadata.Backend] {
return fmt.Errorf("metadata.backend must be one of: sqlite, postgresql, file; got %s", c.Metadata.Backend)
return fmt.Errorf("metadata.backend must be one of: sqlite, postgresql, mysql, file; got %s", c.Metadata.Backend)
}
// Validate cache
@@ -0,0 +1,357 @@
package gormstore
import (
"time"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
)
// AggregationWorker handles background aggregation of download statistics.
// It is driven by Start, which loops on the ticker until Stop closes stopChan.
type AggregationWorker struct {
	db       *gorm.DB      // handle every aggregation statement is executed through
	stopChan chan struct{} // closed by Stop to make the Start loop return
	ticker   *time.Ticker  // created by NewAggregationWorker with a one-hour period
}
// NewAggregationWorker builds a worker bound to db.
//
// The one-hour ticker is created here, so it starts counting at construction
// time rather than when Start is called. The ticker is only stopped by the
// Start loop after Stop has been called.
func NewAggregationWorker(db *gorm.DB) *AggregationWorker {
	worker := new(AggregationWorker)
	worker.db = db
	worker.stopChan = make(chan struct{})
	worker.ticker = time.NewTicker(time.Hour) // run every hour
	return worker
}
// Start runs the aggregation loop and blocks until Stop is called, so callers
// normally invoke it on its own goroutine.
//
// One hourly aggregation is performed immediately, then one per ticker tick.
// A tick that lands while the local clock reads hour 0 additionally triggers
// the daily aggregation. Aggregation errors are logged and never terminate
// the loop.
func (w *AggregationWorker) Start() {
	log.Info().Msg("Starting aggregation worker")

	// Catch up right away instead of waiting for the first tick.
	initialErr := w.AggregateHourly()
	if initialErr != nil {
		log.Error().Err(initialErr).Msg("Failed to run initial hourly aggregation")
	}

	for {
		// Block until either a tick arrives or shutdown is requested.
		select {
		case <-w.stopChan:
			log.Info().Msg("Stopping aggregation worker")
			w.ticker.Stop()
			return
		case <-w.ticker.C:
		}

		hourlyErr := w.AggregateHourly()
		if hourlyErr != nil {
			log.Error().Err(hourlyErr).Msg("Failed to aggregate hourly stats")
		}

		// Daily roll-up only happens on the tick that falls in the midnight hour.
		if time.Now().Hour() != 0 {
			continue
		}
		dailyErr := w.AggregateDaily()
		if dailyErr != nil {
			log.Error().Err(dailyErr).Msg("Failed to aggregate daily stats")
		}
	}
}
// Stop signals the loop in Start to exit by closing stopChan.
//
// Closing an already-closed channel panics, so a repeated call is turned into
// a no-op by first checking whether the channel has been closed. Nothing ever
// sends on stopChan, so a successful receive can only mean "closed".
//
// The check and the close are not atomic: sequential repeat calls are safe,
// but Stop must still not be invoked concurrently from several goroutines.
func (w *AggregationWorker) Stop() {
	select {
	case <-w.stopChan:
		// Already closed by an earlier Stop; nothing left to signal.
	default:
		close(w.stopChan)
	}
}
// AggregateHourly rolls raw rows from download_events up into
// download_stats_hourly and then purges raw events older than about a day.
//
// Everything runs inside one transaction:
//  1. Per-package buckets are recomputed from every raw event older than the
//     cutoff (the start of the hour containing "five minutes ago"), so only
//     completed hours are aggregated.
//  2. Whole hourly buckets of raw events older than 24 hours are deleted.
//  3. Per-registry totals (rows with a NULL package_id) are rebuilt from the
//     per-package rows. This step is best-effort.
//
// Raw events outlive the run that first aggregates them, so every run re-reads
// events that earlier runs already counted. The upserts therefore replace a
// bucket with the freshly recomputed figures instead of adding to it; adding
// would count each event again on every hourly run until it is purged.
//
// Returns the first error from the per-package aggregation or the purge. A
// failure while rebuilding registry totals is logged and not returned.
func (w *AggregationWorker) AggregateHourly() error {
	startTime := time.Now()
	log.Debug().Msg("Starting hourly aggregation")

	// The SQL below is dialect-specific; unknown dialects get the SQLite form.
	dialectName := w.db.Dialector.Name()

	// Aggregate events older than 5 minutes, rounded down to the hour, to
	// avoid aggregating a bucket that is still receiving events.
	cutoff := startTime.Add(-5 * time.Minute).Truncate(time.Hour)

	// Raw events are kept ~24h for debugging. The purge boundary is rounded
	// down to the hour so a bucket is never left with only part of its events:
	// a later run would recompute that bucket from the remainder and shrink it.
	deleteOlder := startTime.Add(-24 * time.Hour).Truncate(time.Hour)

	return w.db.Transaction(func(tx *gorm.DB) error {
		var aggregateSQL string

		switch dialectName {
		case "postgres":
			// PostgreSQL: date_trunc for time bucketing, ON CONFLICT to
			// overwrite an existing bucket with the recomputed totals.
			aggregateSQL = `
INSERT INTO download_stats_hourly (registry_id, package_id, time_bucket, download_count, unique_ips, auth_downloads, created_at, updated_at)
SELECT
de.registry_id,
de.package_id,
date_trunc('hour', de.downloaded_at) AS time_bucket,
COUNT(*) AS download_count,
COUNT(DISTINCT de.ip_address) AS unique_ips,
COUNT(*) FILTER (WHERE de.authenticated = true) AS auth_downloads,
NOW() AS created_at,
NOW() AS updated_at
FROM download_events de
WHERE de.downloaded_at < ?
GROUP BY de.registry_id, de.package_id, time_bucket
ON CONFLICT (registry_id, COALESCE(package_id, 0), time_bucket)
DO UPDATE SET
download_count = EXCLUDED.download_count,
unique_ips = EXCLUDED.unique_ips,
auth_downloads = EXCLUDED.auth_downloads,
updated_at = NOW()
`
		case "mysql":
			// MySQL: DATE_FORMAT for time bucketing, ON DUPLICATE KEY to
			// overwrite an existing bucket with the recomputed totals.
			aggregateSQL = `
INSERT INTO download_stats_hourly (registry_id, package_id, time_bucket, download_count, unique_ips, auth_downloads, created_at, updated_at)
SELECT
de.registry_id,
de.package_id,
DATE_FORMAT(de.downloaded_at, '%Y-%m-%d %H:00:00') AS time_bucket,
COUNT(*) AS download_count,
COUNT(DISTINCT de.ip_address) AS unique_ips,
SUM(CASE WHEN de.authenticated = true THEN 1 ELSE 0 END) AS auth_downloads,
NOW() AS created_at,
NOW() AS updated_at
FROM download_events de
WHERE de.downloaded_at < ?
GROUP BY de.registry_id, de.package_id, time_bucket
ON DUPLICATE KEY UPDATE
download_count = VALUES(download_count),
unique_ips = VALUES(unique_ips),
auth_downloads = VALUES(auth_downloads),
updated_at = NOW()
`
		default: // SQLite
			// SQLite: strftime for time bucketing; INSERT OR REPLACE swaps in
			// the recomputed row when a conflicting one exists.
			aggregateSQL = `
INSERT OR REPLACE INTO download_stats_hourly (registry_id, package_id, time_bucket, download_count, unique_ips, auth_downloads, created_at, updated_at)
SELECT
de.registry_id,
de.package_id,
strftime('%Y-%m-%d %H:00:00', de.downloaded_at) AS time_bucket,
COUNT(*) AS download_count,
COUNT(DISTINCT de.ip_address) AS unique_ips,
SUM(CASE WHEN de.authenticated = 1 THEN 1 ELSE 0 END) AS auth_downloads,
datetime('now') AS created_at,
datetime('now') AS updated_at
FROM download_events de
WHERE de.downloaded_at < ?
GROUP BY de.registry_id, de.package_id, time_bucket
`
		}

		// Execute aggregation
		if err := tx.Exec(aggregateSQL, cutoff).Error; err != nil {
			return err
		}

		// Purge raw events whose buckets were (re)computed above and are now
		// past the retention window.
		deleteResult := tx.Exec("DELETE FROM download_events WHERE downloaded_at < ?", deleteOlder)
		if deleteResult.Error != nil {
			return deleteResult.Error
		}

		// Rebuild registry-level totals (NULL package_id = registry totals)
		// from the per-package rows.
		var registryAggSQL string
		switch dialectName {
		case "postgres":
			registryAggSQL = `
INSERT INTO download_stats_hourly (registry_id, package_id, time_bucket, download_count, unique_ips, auth_downloads, created_at, updated_at)
SELECT
registry_id,
NULL as package_id,
time_bucket,
SUM(download_count) as download_count,
SUM(unique_ips) as unique_ips,
SUM(auth_downloads) as auth_downloads,
NOW() as created_at,
NOW() as updated_at
FROM download_stats_hourly
WHERE package_id IS NOT NULL
GROUP BY registry_id, time_bucket
ON CONFLICT (registry_id, COALESCE(package_id, 0), time_bucket)
DO UPDATE SET
download_count = EXCLUDED.download_count,
unique_ips = EXCLUDED.unique_ips,
auth_downloads = EXCLUDED.auth_downloads,
updated_at = NOW()
`
		case "mysql":
			registryAggSQL = `
INSERT INTO download_stats_hourly (registry_id, package_id, time_bucket, download_count, unique_ips, auth_downloads, created_at, updated_at)
SELECT
registry_id,
NULL as package_id,
time_bucket,
SUM(download_count) as download_count,
SUM(unique_ips) as unique_ips,
SUM(auth_downloads) as auth_downloads,
NOW() as created_at,
NOW() as updated_at
FROM download_stats_hourly
WHERE package_id IS NOT NULL
GROUP BY registry_id, time_bucket
ON DUPLICATE KEY UPDATE
download_count = VALUES(download_count),
unique_ips = VALUES(unique_ips),
auth_downloads = VALUES(auth_downloads),
updated_at = NOW()
`
		default: // SQLite
			registryAggSQL = `
INSERT OR REPLACE INTO download_stats_hourly (registry_id, package_id, time_bucket, download_count, unique_ips, auth_downloads, created_at, updated_at)
SELECT
registry_id,
NULL as package_id,
time_bucket,
SUM(download_count) as download_count,
SUM(unique_ips) as unique_ips,
SUM(auth_downloads) as auth_downloads,
datetime('now') as created_at,
datetime('now') as updated_at
FROM download_stats_hourly
WHERE package_id IS NOT NULL
GROUP BY registry_id, time_bucket
`
		}

		// Run the best-effort step in a nested transaction (a savepoint).
		// Without it, a failed statement on PostgreSQL leaves the outer
		// transaction aborted and the work done above is lost at commit.
		registryErr := tx.Transaction(func(sp *gorm.DB) error {
			return sp.Exec(registryAggSQL).Error
		})
		if registryErr != nil {
			log.Warn().Err(registryErr).Msg("Failed to aggregate registry totals (continuing anyway)")
		}

		elapsed := time.Since(startTime)
		log.Info().
			Int64("deleted_events", deleteResult.RowsAffected).
			Dur("duration", elapsed).
			Msg("Completed hourly aggregation")

		return nil
	})
}
// AggregateDaily folds one day of download_stats_hourly rows into
// download_stats_daily and then drops hourly rows older than seven days.
//
// The window is [start, start+24h), where start is "24 hours ago" rounded
// down with Truncate(24h). Truncate works on absolute time, so the window is
// aligned to UTC midnight rather than to the local calendar day.
//
// Both statements run in a single transaction; the first error aborts it and
// is returned.
func (w *AggregationWorker) AggregateDaily() error {
	began := time.Now()
	log.Debug().Msg("Starting daily aggregation")

	// Aggregate yesterday's data: half-open window [windowStart, windowEnd).
	windowStart := time.Now().AddDate(0, 0, -1).Truncate(24 * time.Hour)
	windowEnd := windowStart.Add(24 * time.Hour)

	// Pick the dialect-specific roll-up statement up front; anything that is
	// neither MySQL nor PostgreSQL is treated as SQLite.
	var rollupSQL string
	switch w.db.Dialector.Name() {
	case "mysql":
		rollupSQL = `
INSERT INTO download_stats_daily (registry_id, package_id, time_bucket, download_count, unique_ips, auth_downloads, top_user_agents, created_at, updated_at)
SELECT
registry_id,
package_id,
DATE_FORMAT(time_bucket, '%Y-%m-%d 00:00:00') AS time_bucket,
SUM(download_count) AS download_count,
MAX(unique_ips) AS unique_ips,
SUM(auth_downloads) AS auth_downloads,
'{}' AS top_user_agents,
NOW() AS created_at,
NOW() AS updated_at
FROM download_stats_hourly
WHERE time_bucket >= ? AND time_bucket < ?
GROUP BY registry_id, package_id, DATE_FORMAT(time_bucket, '%Y-%m-%d 00:00:00')
ON DUPLICATE KEY UPDATE
download_count = VALUES(download_count),
unique_ips = VALUES(unique_ips),
auth_downloads = VALUES(auth_downloads),
updated_at = NOW()
`
	case "postgres":
		rollupSQL = `
INSERT INTO download_stats_daily (registry_id, package_id, time_bucket, download_count, unique_ips, auth_downloads, top_user_agents, created_at, updated_at)
SELECT
registry_id,
package_id,
date_trunc('day', time_bucket) AS time_bucket,
SUM(download_count) AS download_count,
MAX(unique_ips) AS unique_ips,
SUM(auth_downloads) AS auth_downloads,
'{}' AS top_user_agents,
NOW() AS created_at,
NOW() AS updated_at
FROM download_stats_hourly
WHERE time_bucket >= ? AND time_bucket < ?
GROUP BY registry_id, package_id, date_trunc('day', time_bucket)
ON CONFLICT (registry_id, COALESCE(package_id, 0), time_bucket)
DO UPDATE SET
download_count = EXCLUDED.download_count,
unique_ips = EXCLUDED.unique_ips,
auth_downloads = EXCLUDED.auth_downloads,
updated_at = NOW()
`
	default: // SQLite
		rollupSQL = `
INSERT OR REPLACE INTO download_stats_daily (registry_id, package_id, time_bucket, download_count, unique_ips, auth_downloads, top_user_agents, created_at, updated_at)
SELECT
registry_id,
package_id,
date(time_bucket) AS time_bucket,
SUM(download_count) AS download_count,
MAX(unique_ips) AS unique_ips,
SUM(auth_downloads) AS auth_downloads,
'{}' AS top_user_agents,
datetime('now') AS created_at,
datetime('now') AS updated_at
FROM download_stats_hourly
WHERE time_bucket >= ? AND time_bucket < ?
GROUP BY registry_id, package_id, date(time_bucket)
`
	}

	return w.db.Transaction(func(tx *gorm.DB) error {
		if rollup := tx.Exec(rollupSQL, windowStart, windowEnd); rollup.Error != nil {
			return rollup.Error
		}

		// Delete old hourly stats (keep last 7 days).
		retention := time.Now().AddDate(0, 0, -7)
		purge := tx.Exec("DELETE FROM download_stats_hourly WHERE time_bucket < ?", retention)
		if purge.Error != nil {
			return purge.Error
		}

		log.Info().
			Int64("deleted_hourly_stats", purge.RowsAffected).
			Dur("duration", time.Since(began)).
			Msg("Completed daily aggregation")

		return nil
	})
}
// UpdatePackageAccessCounts overwrites packages.access_count for every package
// with the sum of its download_count rows in download_stats_hourly (0 when a
// package has none).
//
// The statement deliberately avoids an alias on the UPDATE target: SQLite only
// accepts one when it is introduced with AS, so the correlated subquery refers
// to the outer row as packages.id, which PostgreSQL, MySQL and SQLite all
// accept.
//
// NOTE(review): AggregateDaily deletes hourly rows older than seven days, so
// the sum covers only the rows still retained, not all-time downloads —
// confirm that is the intended meaning of access_count.
//
// Returns the error reported by the database, if any.
func (w *AggregationWorker) UpdatePackageAccessCounts() error {
	log.Debug().Msg("Updating package access counts")

	// Sum the retained hourly download counts per package.
	const updateSQL = `
UPDATE packages
SET access_count = COALESCE((
SELECT SUM(dsh.download_count)
FROM download_stats_hourly dsh
WHERE dsh.package_id = packages.id
), 0)
`

	if err := w.db.Exec(updateSQL).Error; err != nil {
		return err
	}

	log.Info().Msg("Updated package access counts")
	return nil
}
+78
View File
@@ -0,0 +1,78 @@
package gormstore
import (
"fmt"
"time"
"github.com/lukaszraczylo/gohoarder/pkg/errors"
)
// Config holds GORM store configuration.
// Zero-valued pool settings and an empty LogLevel are replaced with defaults
// by Validate.
type Config struct {
	// Database connection
	Driver string // "sqlite", "postgres", "mysql"; required
	DSN    string // Data Source Name; required (see the Build*DSN helpers)

	// Connection pool
	MaxOpenConns    int           // Validate sets 25 when left at 0
	MaxIdleConns    int           // Validate sets 5 when left at 0
	ConnMaxLifetime time.Duration // Validate sets one hour when left at 0

	// GORM settings
	LogLevel string // "silent", "error", "warn", "info"; Validate sets "warn" when empty
}
// Validate checks the required fields and fills in defaults.
//
// It returns an ErrCodeInvalidConfig error when Driver or DSN is empty
// (Driver is checked first). Otherwise it mutates the receiver, replacing
// zero-valued pool settings and an empty LogLevel with their defaults, and
// returns nil.
func (c *Config) Validate() error {
	switch {
	case c.Driver == "":
		return errors.New(errors.ErrCodeInvalidConfig, "driver is required")
	case c.DSN == "":
		return errors.New(errors.ErrCodeInvalidConfig, "DSN is required")
	}

	// Everything below only touches fields the caller left unset.
	if c.MaxOpenConns == 0 {
		c.MaxOpenConns = 25
	}
	if c.MaxIdleConns == 0 {
		c.MaxIdleConns = 5
	}
	if c.ConnMaxLifetime == 0 {
		c.ConnMaxLifetime = time.Hour
	}
	if c.LogLevel == "" {
		c.LogLevel = "warn"
	}
	return nil
}
// BuildPostgresDSN builds a PostgreSQL keyword/value DSN from structured
// config. An empty sslmode defaults to "disable".
//
// In the keyword/value format an empty value, or one containing whitespace,
// a single quote or a backslash, must be wrapped in single quotes with ' and \
// backslash-escaped. Without that, a password such as "p w" is cut at the
// space and an empty field swallows the next keyword. Such values are quoted
// here; plain values are emitted unchanged.
func BuildPostgresDSN(host string, port int, user, password, database, sslmode string) string {
	if sslmode == "" {
		sslmode = "disable"
	}
	return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
		quotePostgresDSNValue(host),
		port,
		quotePostgresDSNValue(user),
		quotePostgresDSNValue(password),
		quotePostgresDSNValue(database),
		quotePostgresDSNValue(sslmode),
	)
}

// quotePostgresDSNValue returns v ready for use as a keyword/value DSN value:
// unchanged when it is safe as-is, otherwise single-quoted with ' and \
// escaped by a backslash.
func quotePostgresDSNValue(v string) string {
	needsQuoting := v == ""
	for i := 0; i < len(v) && !needsQuoting; i++ {
		switch v[i] {
		case ' ', '\t', '\n', '\r', '\v', '\f', '\'', '\\':
			needsQuoting = true
		}
	}
	if !needsQuoting {
		return v
	}

	quoted := make([]byte, 0, len(v)+2)
	quoted = append(quoted, '\'')
	for i := 0; i < len(v); i++ {
		if v[i] == '\'' || v[i] == '\\' {
			quoted = append(quoted, '\\')
		}
		quoted = append(quoted, v[i])
	}
	quoted = append(quoted, '\'')
	return string(quoted)
}
// BuildMySQLDSN assembles a MySQL/MariaDB DSN of the form
// user:password@tcp(host:port)/database?charset=...&parseTime=True&loc=Local.
// An empty charset falls back to "utf8mb4". The parts are inserted verbatim;
// no escaping is applied.
func BuildMySQLDSN(host string, port int, user, password, database, charset string) string {
	effectiveCharset := charset
	if effectiveCharset == "" {
		effectiveCharset = "utf8mb4"
	}

	credentials := fmt.Sprintf("%s:%s", user, password)
	address := fmt.Sprintf("tcp(%s:%d)", host, port)

	return fmt.Sprintf("%s@%s/%s?charset=%s&parseTime=True&loc=Local",
		credentials, address, database, effectiveCharset)
}
// BuildSQLiteDSN appends journal/timeout pragmas to the database path as query
// parameters. An empty path falls back to "gohoarder.db". With walMode the
// journal mode is WAL and a cache size is set; otherwise the journal mode is
// DELETE.
func BuildSQLiteDSN(path string, walMode bool) string {
	dbFile := path
	if dbFile == "" {
		dbFile = "gohoarder.db"
	}

	pragmas := "_journal_mode=DELETE&_busy_timeout=5000&_synchronous=NORMAL"
	if walMode {
		pragmas = "_journal_mode=WAL&_busy_timeout=5000&_synchronous=NORMAL&_cache_size=2000"
	}

	return fmt.Sprintf("%s?%s", dbFile, pragmas)
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,279 @@
//go:build integration
// +build integration
package gormstore
import (
	"context"
	"testing"
	"time"

	"github.com/lukaszraczylo/gohoarder/pkg/metadata"
	"github.com/stretchr/testify/suite"
	"github.com/testcontainers/testcontainers-go"
	"github.com/testcontainers/testcontainers-go/modules/mysql"
	"github.com/testcontainers/testcontainers-go/wait"
	"gorm.io/gorm"
)
// MySQLV2IntegrationTestSuite embeds the V2 test suite with MySQL container.
// One container is shared by the whole suite; a fresh store is opened against
// it before every test.
type MySQLV2IntegrationTestSuite struct {
	// Shared test cases. The ctx and store fields used by the methods of this
	// suite are not declared here, so they presumably come from this embed.
	GORMStoreV2TestSuite
	container *mysql.MySQLContainer // set in SetupSuite, terminated in TearDownSuite
}
// SetupSuite starts the MySQL container shared by every test in the suite and
// aborts the suite if it cannot be started.
func (s *MySQLV2IntegrationTestSuite) SetupSuite() {
	// Readiness is detected from the server log, with a 60 second budget.
	// NOTE(review): the pattern must match the server's log line exactly,
	// including spacing — confirm against the output of the mysql:8.0 image.
	readiness := wait.ForLog("port: 3306 MySQL Community Server").
		WithOccurrence(1).
		WithStartupTimeout(60 * time.Second)

	started, err := mysql.RunContainer(
		context.Background(),
		testcontainers.WithImage("mysql:8.0"),
		mysql.WithDatabase("testdb"),
		mysql.WithUsername("testuser"),
		mysql.WithPassword("testpass"),
		testcontainers.WithWaitStrategy(readiness),
	)
	s.Require().NoError(err)

	s.container = started
}
// TearDownSuite terminates the shared MySQL container, if one was started,
// and records a test failure when termination reports an error.
func (s *MySQLV2IntegrationTestSuite) TearDownSuite() {
	if s.container == nil {
		return
	}
	s.NoError(s.container.Terminate(context.Background()))
}
// SetupTest opens a fresh GORM store against the suite's MySQL container
// before each test. Any failure aborts the test.
func (s *MySQLV2IntegrationTestSuite) SetupTest() {
	s.ctx = context.Background()

	// The container reports the DSN for the database it was started with.
	dsn, err := s.container.ConnectionString(s.ctx)
	s.Require().NoError(err)

	s.store, err = NewV2(Config{
		Driver:          "mysql",
		DSN:             dsn,
		MaxOpenConns:    10,
		MaxIdleConns:    5,
		ConnMaxLifetime: time.Hour,
		LogLevel:        "silent",
	})
	s.Require().NoError(err)
	s.Require().NotNil(s.store)
}
// TearDownTest empties every table after each test, re-seeds the default
// registries, rebuilds the registry cache and closes the store.
//
// Rows are removed with DELETE rather than TRUNCATE. MySQL refuses
// TRUNCATE TABLE on an InnoDB table that another table references through a
// foreign key, and the result of each statement is not checked here, so
// truncation silently left parent tables such as packages and registries
// populated. The list is ordered children-first so that DELETE does not
// violate those constraints.
//
// NOTE(review): unlike TRUNCATE, DELETE does not reset AUTO_INCREMENT
// counters; registry IDs therefore differ between tests.
func (s *MySQLV2IntegrationTestSuite) TearDownTest() {
	if s.store == nil {
		return
	}

	// Clean up all tables for next test (children before parents).
	tables := []string{
		"download_events",
		"download_stats_hourly",
		"download_stats_daily",
		"audit_log",
		"cve_bypasses",
		"scan_results",
		"package_vulnerabilities",
		"vulnerabilities",
		"package_metadata",
		"packages",
		"registries",
	}
	for _, table := range tables {
		// Best-effort cleanup: the result is intentionally not checked.
		s.store.db.Exec("DELETE FROM " + table)
	}

	// Re-seed default registries after the cleanup.
	defaultRegistries := []RegistryModel{
		{Name: "npm", DisplayName: "NPM Registry", UpstreamURL: "https://registry.npmjs.org", Enabled: true, ScanByDefault: true},
		{Name: "pypi", DisplayName: "PyPI", UpstreamURL: "https://pypi.org", Enabled: true, ScanByDefault: true},
		{Name: "go", DisplayName: "Go Modules", UpstreamURL: "https://proxy.golang.org", Enabled: true, ScanByDefault: true},
	}
	for i := range defaultRegistries {
		s.store.db.Create(&defaultRegistries[i])
	}

	// Rebuild registry cache
	s.store.rebuildRegistryCache()

	s.store.Close()
}
// TestMySQLV2IntegrationTestSuite is the go test entry point for the MySQL
// suite. It is skipped under -short because it needs a running container.
func TestMySQLV2IntegrationTestSuite(t *testing.T) {
	if !testing.Short() {
		suite.Run(t, &MySQLV2IntegrationTestSuite{})
		return
	}
	t.Skip("Skipping integration tests in short mode")
}
// Test_MySQLV2_SpecificFeatures verifies that the store is really talking to
// a MySQL server.
//
// It reads @@version_comment, which names the server distribution (for
// example "MySQL Community Server - GPL"). VERSION() is not used because on
// MySQL it returns only the version number, which never contains "MySQL".
func (s *MySQLV2IntegrationTestSuite) Test_MySQLV2_SpecificFeatures() {
	var versionComment string
	err := s.store.db.Raw("SELECT @@version_comment").Scan(&versionComment).Error
	s.NoError(err)
	s.Contains(versionComment, "MySQL")
}
// Test_MySQLV2_NoPartitioning asserts that a store opened with the MySQL
// driver has no partition manager attached.
func (s *MySQLV2IntegrationTestSuite) Test_MySQLV2_NoPartitioning() {
	// MySQL doesn't use our partition manager, so the field must stay nil.
	manager := s.store.partitionManager
	s.Nil(manager)
}
// Test_MySQLV2_HighConcurrency saves one package, bumps its download count
// from 15 goroutines at once and checks that every increment was applied.
func (s *MySQLV2IntegrationTestSuite) Test_MySQLV2_HighConcurrency() {
	const workers = 15

	seed := &metadata.Package{
		Registry:     "npm",
		Name:         "concurrent-test",
		Version:      "1.0.0",
		StorageKey:   "npm/concurrent-test/1.0.0.tgz",
		Size:         1000,
		CachedAt:     time.Now(),
		LastAccessed: time.Now(),
	}
	s.NoError(s.store.SavePackage(s.ctx, seed))

	// Buffered to the worker count so no goroutine blocks on reporting back.
	finished := make(chan bool, workers)
	bump := func() {
		s.NoError(s.store.UpdateDownloadCount(s.ctx, "npm", "concurrent-test", "1.0.0"))
		finished <- true
	}

	for remaining := workers; remaining > 0; remaining-- {
		go bump()
	}

	// Wait for all goroutines.
	for remaining := workers; remaining > 0; remaining-- {
		<-finished
	}

	// Every increment must be visible in the stored counter.
	stored, err := s.store.GetPackage(s.ctx, "npm", "concurrent-test", "1.0.0")
	s.NoError(err)
	s.Equal(int64(workers), stored.DownloadCount)
}
// Test_MySQLV2_JSON saves a package carrying a metadata map and checks that
// selected keys survive a round trip through the store.
//
// The map is held in pkgMetadata rather than a variable called metadata: a
// local with that name shadows the imported metadata package, after which
// metadata.Package no longer resolves and the file does not compile.
func (s *MySQLV2IntegrationTestSuite) Test_MySQLV2_JSON() {
	// NOTE(review): confirm that Package.Metadata accepts this map type;
	// other code builds that field with make(map[string]string).
	pkgMetadata := map[string]interface{}{
		"author":      "Test Author",
		"license":     "MIT",
		"description": "Test package",
		"keywords":    []interface{}{"test", "mysql", "json"},
	}

	pkg := &metadata.Package{
		Registry:     "npm",
		Name:         "json-test",
		Version:      "1.0.0",
		StorageKey:   "npm/json-test/1.0.0.tgz",
		Size:         1000,
		CachedAt:     time.Now(),
		LastAccessed: time.Now(),
		Metadata:     pkgMetadata,
	}

	err := s.store.SavePackage(s.ctx, pkg)
	s.NoError(err)

	// Retrieve and verify JSON data
	retrieved, err := s.store.GetPackage(s.ctx, "npm", "json-test", "1.0.0")
	s.NoError(err)
	s.NotNil(retrieved.Metadata)
	s.Equal("MIT", retrieved.Metadata["license"])
	s.Equal("Test Author", retrieved.Metadata["author"])
}
// Test_MySQLV2_TransactionRollback checks that a failing transaction leaves no
// trace: an access_count increment is followed by an insert that is expected
// to fail, after which the package's download count must still read 0.
func (s *MySQLV2IntegrationTestSuite) Test_MySQLV2_TransactionRollback() {
	seed := &metadata.Package{
		Registry:     "npm",
		Name:         "tx-test",
		Version:      "1.0.0",
		StorageKey:   "npm/tx-test/1.0.0.tgz",
		Size:         1000,
		CachedAt:     time.Now(),
		LastAccessed: time.Now(),
	}
	s.NoError(s.store.SavePackage(s.ctx, seed))

	// Two steps inside one transaction; the second is meant to fail.
	attempt := func(tx *gorm.DB) error {
		// Step 1: a valid update that should succeed on its own.
		bumpErr := tx.Model(&PackageModel{}).
			Where("registry_id = ? AND name = ? AND version = ?",
				s.store.registryCache["npm"], "tx-test", "1.0.0").
			Update("access_count", gorm.Expr("access_count + ?", 1)).
			Error
		if bumpErr != nil {
			return bumpErr
		}

		// Step 2: insert a row pointing at a registry that does not exist
		// (invalid foreign key).
		orphan := &PackageModel{
			RegistryID:   9999, // Non-existent registry
			Name:         "invalid",
			Version:      "1.0.0",
			StorageKey:   "invalid",
			Size:         100,
			CachedAt:     time.Now(),
			LastAccessed: time.Now(),
		}
		return tx.Create(orphan).Error
	}

	// Transaction should fail
	txErr := s.store.db.Transaction(attempt)
	s.Error(txErr)

	// Verify first update was rolled back
	after, err := s.store.GetPackage(s.ctx, "npm", "tx-test", "1.0.0")
	s.NoError(err)
	s.Equal(int64(0), after.DownloadCount) // Should still be 0, not 1
}
// Test_MySQLV2_CharacterSet stores a package whose name and description
// contain CJK characters and emoji, then checks both come back intact.
func (s *MySQLV2IntegrationTestSuite) Test_MySQLV2_CharacterSet() {
	const unicodeName = "unicode-test-世界-🚀"

	// Test package with Unicode characters
	original := &metadata.Package{
		Registry:     "npm",
		Name:         unicodeName,
		Version:      "1.0.0",
		StorageKey:   "npm/unicode-test/1.0.0.tgz",
		Size:         1000,
		CachedAt:     time.Now(),
		LastAccessed: time.Now(),
		Metadata: map[string]interface{}{
			"description": "Test with emoji 🎉 and Chinese 中文",
		},
	}
	s.NoError(s.store.SavePackage(s.ctx, original))

	// Retrieve and verify Unicode data preserved
	roundTripped, err := s.store.GetPackage(s.ctx, "npm", unicodeName, "1.0.0")
	s.NoError(err)
	s.Equal(unicodeName, roundTripped.Name)

	description := roundTripped.Metadata["description"]
	s.Contains(description, "🎉")
	s.Contains(description, "中文")
}
@@ -0,0 +1,202 @@
//go:build integration
// +build integration
package gormstore
import (
"context"
"testing"
"time"
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
"github.com/stretchr/testify/suite"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
)
// PostgresV2IntegrationTestSuite embeds the V2 test suite with PostgreSQL container.
// The embedded suite contributes the shared store/ctx fields and its test
// methods; this type overrides the setup and teardown hooks.
type PostgresV2IntegrationTestSuite struct {
GORMStoreV2TestSuite
// container is assigned in SetupSuite and terminated in TearDownSuite.
container *postgres.PostgresContainer
}
// SetupSuite starts the PostgreSQL container used by every test in the suite
// and aborts the suite if the container cannot be started.
func (s *PostgresV2IntegrationTestSuite) SetupSuite() {
	// Consider the container ready once the readiness log line has been
	// seen twice, giving up after 60 seconds.
	ready := wait.ForLog("database system is ready to accept connections").
		WithOccurrence(2).
		WithStartupTimeout(60 * time.Second)

	pg, err := postgres.RunContainer(context.Background(),
		testcontainers.WithImage("postgres:16-alpine"),
		postgres.WithDatabase("testdb"),
		postgres.WithUsername("testuser"),
		postgres.WithPassword("testpass"),
		testcontainers.WithWaitStrategy(ready),
	)
	s.Require().NoError(err)

	s.container = pg
}
// TearDownSuite terminates the PostgreSQL container, if one was started.
func (s *PostgresV2IntegrationTestSuite) TearDownSuite() {
	if s.container == nil {
		return
	}
	s.NoError(s.container.Terminate(context.Background()))
}
// SetupTest runs before each test: it resets the context and opens a fresh
// GORM store against the suite's PostgreSQL container.
func (s *PostgresV2IntegrationTestSuite) SetupTest() {
	s.ctx = context.Background()

	// Let the container helper append the query parameter itself rather than
	// concatenating onto whatever the returned URL happens to end with.
	connStr, err := s.container.ConnectionString(s.ctx, "sslmode=disable")
	s.Require().NoError(err)

	// Create GORM store with PostgreSQL
	cfg := Config{
		Driver:          "postgres",
		DSN:             connStr,
		MaxOpenConns:    10,
		MaxIdleConns:    5,
		ConnMaxLifetime: 3600 * time.Second,
		LogLevel:        "silent",
	}

	s.store, err = NewV2(cfg)
	s.Require().NoError(err)
	s.Require().NotNil(s.store)
}
// TearDownTest empties every data table after each test and closes the store.
// Truncation and close errors are not checked.
func (s *PostgresV2IntegrationTestSuite) TearDownTest() {
	if s.store == nil {
		return
	}

	// Clean up all tables for next test
	for _, table := range []string{
		"download_events",
		"download_stats_hourly",
		"download_stats_daily",
		"audit_log",
		"cve_bypasses",
		"scan_results",
		"package_vulnerabilities",
		"vulnerabilities",
		"package_metadata",
		"packages",
	} {
		s.store.db.Exec("TRUNCATE TABLE " + table + " CASCADE")
	}

	s.store.Close()
}
// TestPostgresV2IntegrationTestSuite is the `go test` entry point for the
// PostgreSQL integration suite; it is skipped when -short is set.
func TestPostgresV2IntegrationTestSuite(t *testing.T) {
	if !testing.Short() {
		suite.Run(t, new(PostgresV2IntegrationTestSuite))
		return
	}
	t.Skip("Skipping integration tests in short mode")
}
// Test_PostgresV2_SpecificFeatures confirms the store is really backed by
// PostgreSQL by inspecting the server's version() string.
func (s *PostgresV2IntegrationTestSuite) Test_PostgresV2_SpecificFeatures() {
	var banner string
	queryErr := s.store.db.Raw("SELECT version()").Scan(&banner).Error

	s.NoError(queryErr)
	s.Contains(banner, "PostgreSQL")
}
// Test_PostgresV2_Partitioning checks that the partition manager exists and
// reports at least one partition for both partitioned tables.
func (s *PostgresV2IntegrationTestSuite) Test_PostgresV2_Partitioning() {
	// Hard requirements: everything below dereferences these values, so a
	// soft assertion would only be followed by a panic.
	s.Require().NotNil(s.store.partitionManager)

	info, err := s.store.partitionManager.GetPartitionInfo()
	s.Require().NoError(err)
	s.Require().NotNil(info)

	// Each counter must be present as an int64 and be positive. The checked
	// type assertion yields a readable failure instead of a panic when a key
	// is missing or has another type.
	for _, key := range []string{"download_events_partitions", "audit_log_partitions"} {
		count, ok := info[key].(int64)
		s.Require().Truef(ok, "partition info %q should be an int64, got %T", key, info[key])
		s.Greater(count, int64(0), key)
	}
}
// Test_PostgresV2_HighConcurrency increments one package's download count
// from 20 goroutines at once and expects every increment to be recorded.
func (s *PostgresV2IntegrationTestSuite) Test_PostgresV2_HighConcurrency() {
	target := &metadata.Package{
		Registry:     "npm",
		Name:         "concurrent-test",
		Version:      "1.0.0",
		StorageKey:   "npm/concurrent-test/1.0.0.tgz",
		Size:         1000,
		CachedAt:     time.Now(),
		LastAccessed: time.Now(),
	}
	s.NoError(s.store.SavePackage(s.ctx, target))

	const workers = 20
	// Buffered to the worker count so no goroutine blocks on its signal.
	finished := make(chan bool, workers)

	for w := 0; w < workers; w++ {
		go func() {
			updateErr := s.store.UpdateDownloadCount(s.ctx, "npm", "concurrent-test", "1.0.0")
			s.NoError(updateErr)
			finished <- true
		}()
	}

	// Collect one completion signal per worker before reading the result.
	for w := 0; w < workers; w++ {
		<-finished
	}

	got, err := s.store.GetPackage(s.ctx, "npm", "concurrent-test", "1.0.0")
	s.NoError(err)
	s.Equal(int64(workers), got.DownloadCount)
}
// Test_PostgresV2_JSONB saves a package with structured metadata (including a
// nested list) and checks that scalar values come back unchanged.
func (s *PostgresV2IntegrationTestSuite) Test_PostgresV2_JSONB() {
	// Named "meta" rather than "metadata": a local called metadata would
	// shadow the imported package and make metadata.Package below fail to
	// compile.
	meta := map[string]interface{}{
		"author":      "Test Author",
		"license":     "MIT",
		"description": "Test package",
		"keywords":    []interface{}{"test", "postgres", "jsonb"},
	}

	pkg := &metadata.Package{
		Registry:     "npm",
		Name:         "jsonb-test",
		Version:      "1.0.0",
		StorageKey:   "npm/jsonb-test/1.0.0.tgz",
		Size:         1000,
		CachedAt:     time.Now(),
		LastAccessed: time.Now(),
		Metadata:     meta,
	}

	err := s.store.SavePackage(s.ctx, pkg)
	s.NoError(err)

	// Retrieve and verify JSONB data
	retrieved, err := s.store.GetPackage(s.ctx, "npm", "jsonb-test", "1.0.0")
	s.NoError(err)
	s.NotNil(retrieved.Metadata)
	s.Equal("MIT", retrieved.Metadata["license"])
	s.Equal("Test Author", retrieved.Metadata["author"])
}
+871
View File
@@ -0,0 +1,871 @@
package gormstore
import (
"context"
"fmt"
"testing"
"time"
"github.com/lukaszraczylo/gohoarder/pkg/metadata"
"github.com/stretchr/testify/suite"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// GORMStoreV2TestSuite is the test suite for V2 GORM implementation.
// It carries the fixtures shared by all Test_V2_* methods.
type GORMStoreV2TestSuite struct {
suite.Suite
// db is the in-memory SQLite connection opened once in SetupSuite.
db *gorm.DB
// store is the store under test; SetupTest creates it and TearDownTest closes it.
store *GORMStoreV2
// ctx is reset to context.Background() in SetupTest.
ctx context.Context
}
// SetupSuite opens a shared-cache in-memory SQLite connection once for the
// whole suite and keeps it on the suite.
func (s *GORMStoreV2TestSuite) SetupSuite() {
	// Silence GORM's SQL logging to keep test output readable.
	gormCfg := &gorm.Config{
		Logger: logger.Default.LogMode(logger.Silent),
	}

	conn, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), gormCfg)
	s.Require().NoError(err)

	s.db = conn
}
// SetupTest resets the context and builds a new V2 store on the shared
// in-memory SQLite database before every test.
func (s *GORMStoreV2TestSuite) SetupTest() {
	s.ctx = context.Background()

	fresh, err := NewV2(Config{
		Driver:          "sqlite",
		DSN:             "file::memory:?cache=shared",
		MaxOpenConns:    10,
		MaxIdleConns:    5,
		ConnMaxLifetime: 3600 * time.Second,
		LogLevel:        "silent",
	})
	s.Require().NoError(err)
	s.Require().NotNil(fresh)

	s.store = fresh
}
// TearDownTest deletes all rows from the data tables after each test and
// closes the store. Delete and close errors are not checked.
func (s *GORMStoreV2TestSuite) TearDownTest() {
	if s.store == nil {
		return
	}

	// Clean up all tables for next test
	for _, table := range []string{
		"audit_log",
		"download_stats_daily",
		"download_stats_hourly",
		"download_events",
		"cve_bypasses",
		"scan_results",
		"package_vulnerabilities",
		"vulnerabilities",
		"package_metadata",
		"packages",
	} {
		s.store.db.Exec("DELETE FROM " + table)
	}

	s.store.Close()
}
// TestGORMStoreV2TestSuite is the `go test` entry point for the V2 store suite.
func TestGORMStoreV2TestSuite(t *testing.T) {
	suite.Run(t, &GORMStoreV2TestSuite{})
}
// Test_V2_SavePackage_Success stores a package and reads it back, comparing
// the identifying fields and the size.
func (s *GORMStoreV2TestSuite) Test_V2_SavePackage_Success() {
	input := &metadata.Package{
		Registry:       "npm",
		Name:           "test-package",
		Version:        "1.0.0",
		StorageKey:     "npm/test-package/1.0.0.tgz",
		Size:           12345,
		ChecksumMD5:    "abc123",
		ChecksumSHA256: "def456",
		UpstreamURL:    "https://registry.npmjs.org/test-package",
		CachedAt:       time.Now(),
		LastAccessed:   time.Now(),
	}
	s.NoError(s.store.SavePackage(s.ctx, input))

	// Round trip through the store.
	got, err := s.store.GetPackage(s.ctx, "npm", "test-package", "1.0.0")
	s.NoError(err)
	s.NotNil(got)

	s.Equal("npm", got.Registry)
	s.Equal("test-package", got.Name)
	s.Equal("1.0.0", got.Version)
	s.Equal(int64(12345), got.Size)
}
// Test_V2_SavePackage_WithMetadata saves a package carrying author, license,
// homepage and description metadata, then checks the package_metadata row
// linked to it.
func (s *GORMStoreV2TestSuite) Test_V2_SavePackage_WithMetadata() {
	input := &metadata.Package{
		Registry:     "npm",
		Name:         "meta-package",
		Version:      "2.0.0",
		StorageKey:   "npm/meta-package/2.0.0.tgz",
		Size:         5000,
		CachedAt:     time.Now(),
		LastAccessed: time.Now(),
		Metadata: map[string]string{
			"author":      "Test Author",
			"license":     "MIT",
			"homepage":    "https://example.com",
			"description": "Test package description",
		},
	}
	s.NoError(s.store.SavePackage(s.ctx, input))

	// The metadata lives in its own table; find it through the package's id.
	var row PackageMetadataModel
	lookup := s.store.db.Where("package_id = (SELECT id FROM packages WHERE name = ?)", "meta-package")
	s.NoError(lookup.First(&row).Error)

	s.Equal("Test Author", row.Author)
	s.Equal("MIT", row.License)
	s.Equal("https://example.com", row.Homepage)
}
// Test_V2_SavePackage_Upsert saves the same registry/name/version twice and
// expects the second save to overwrite the stored size and MD5 checksum.
func (s *GORMStoreV2TestSuite) Test_V2_SavePackage_Upsert() {
	record := &metadata.Package{
		Registry:     "npm",
		Name:         "upsert-test",
		Version:      "1.0.0",
		StorageKey:   "npm/upsert-test/1.0.0.tgz",
		Size:         1000,
		CachedAt:     time.Now(),
		LastAccessed: time.Now(),
	}

	// First write creates the row.
	s.NoError(s.store.SavePackage(s.ctx, record))

	// Second write with changed fields must update it in place.
	record.Size = 2000
	record.ChecksumMD5 = "updated"
	s.NoError(s.store.SavePackage(s.ctx, record))

	got, err := s.store.GetPackage(s.ctx, "npm", "upsert-test", "1.0.0")
	s.NoError(err)
	s.Equal(int64(2000), got.Size)
	s.Equal("updated", got.ChecksumMD5)
}
// Test_V2_GetPackage_NotFound expects GetPackage to return a "not found"
// error for a package that was never saved.
func (s *GORMStoreV2TestSuite) Test_V2_GetPackage_NotFound() {
	_, err := s.store.GetPackage(s.ctx, "npm", "nonexistent", "1.0.0")

	// Require (not a soft assert): err.Error() below would dereference nil
	// if the lookup unexpectedly succeeded.
	s.Require().Error(err)
	s.Contains(err.Error(), "not found")
}
// Test_V2_DeletePackage_Success deletes a package and verifies it is no longer
// returned by GetPackage while its row is still present for unscoped queries.
func (s *GORMStoreV2TestSuite) Test_V2_DeletePackage_Success() {
	// Save package
	pkg := &metadata.Package{
		Registry:     "npm",
		Name:         "delete-test",
		Version:      "1.0.0",
		StorageKey:   "npm/delete-test/1.0.0.tgz",
		Size:         1000,
		CachedAt:     time.Now(),
		LastAccessed: time.Now(),
	}
	err := s.store.SavePackage(s.ctx, pkg)
	s.NoError(err)

	// Delete package (soft delete)
	err = s.store.DeletePackage(s.ctx, "npm", "delete-test", "1.0.0")
	s.NoError(err)

	// Verify deleted (should not be found)
	_, err = s.store.GetPackage(s.ctx, "npm", "delete-test", "1.0.0")
	s.Error(err)

	// Verify soft delete: the row must still be countable when the
	// soft-delete scope is bypassed. The query error is checked so a failed
	// query is reported as such rather than as a wrong count.
	var count int64
	err = s.store.db.Unscoped().Model(&PackageModel{}).
		Where("name = ?", "delete-test").
		Count(&count).Error
	s.NoError(err)
	s.Equal(int64(1), count) // Still in DB, just soft deleted
}
// Test_V2_ListPackages_All saves five packages and expects an unfiltered
// listing to return all of them.
func (s *GORMStoreV2TestSuite) Test_V2_ListPackages_All() {
	const total = 5

	for n := 0; n < total; n++ {
		name := fmt.Sprintf("package-%d", n)
		key := fmt.Sprintf("npm/package-%d/1.0.0.tgz", n)
		s.NoError(s.store.SavePackage(s.ctx, &metadata.Package{
			Registry:     "npm",
			Name:         name,
			Version:      "1.0.0",
			StorageKey:   key,
			Size:         int64(n * 1000),
			CachedAt:     time.Now(),
			LastAccessed: time.Now(),
		}))
	}

	// Empty options: no filter, no pagination.
	listed, err := s.store.ListPackages(s.ctx, &metadata.ListOptions{})
	s.NoError(err)
	s.Len(listed, total)
}
// Test_V2_ListPackages_FilterByRegistry saves the same package name in three
// registries and expects a registry filter to return only the matching one.
func (s *GORMStoreV2TestSuite) Test_V2_ListPackages_FilterByRegistry() {
	// Create packages in different registries
	registries := []string{"npm", "pypi", "go"}
	for _, reg := range registries {
		pkg := &metadata.Package{
			Registry:     reg,
			Name:         "test-package",
			Version:      "1.0.0",
			StorageKey:   fmt.Sprintf("%s/test-package/1.0.0", reg),
			Size:         1000,
			CachedAt:     time.Now(),
			LastAccessed: time.Now(),
		}
		err := s.store.SavePackage(s.ctx, pkg)
		s.NoError(err)
	}

	// Filter by npm registry
	packages, err := s.store.ListPackages(s.ctx, &metadata.ListOptions{
		Registry: "npm",
	})
	s.NoError(err)

	// Require the length: packages[0] below would panic on an empty result
	// and hide the actual failure.
	s.Require().Len(packages, 1)
	s.Equal("npm", packages[0].Registry)
}
// Test_V2_ListPackages_Pagination saves ten packages and checks that two
// consecutive pages of five are both full and start with different packages.
func (s *GORMStoreV2TestSuite) Test_V2_ListPackages_Pagination() {
	// Create 10 packages
	for i := 0; i < 10; i++ {
		pkg := &metadata.Package{
			Registry:     "npm",
			Name:         fmt.Sprintf("package-%d", i),
			Version:      "1.0.0",
			StorageKey:   fmt.Sprintf("npm/package-%d/1.0.0.tgz", i),
			Size:         int64(i * 1000),
			CachedAt:     time.Now(),
			LastAccessed: time.Now(),
		}
		err := s.store.SavePackage(s.ctx, pkg)
		s.NoError(err)
	}

	// Get first page (5 items)
	page1, err := s.store.ListPackages(s.ctx, &metadata.ListOptions{
		Limit:  5,
		Offset: 0,
	})
	s.NoError(err)
	// Required: page1[0] is indexed below.
	s.Require().Len(page1, 5)

	// Get second page (5 items)
	page2, err := s.store.ListPackages(s.ctx, &metadata.ListOptions{
		Limit:  5,
		Offset: 5,
	})
	s.NoError(err)
	// Required: page2[0] is indexed below.
	s.Require().Len(page2, 5)

	// Verify different packages
	s.NotEqual(page1[0].Name, page2[0].Name)
}
// Test_V2_UpdateDownloadCount_Success increments a package's download count
// twice, checking the counter after each call and the number of recorded
// download events at the end.
func (s *GORMStoreV2TestSuite) Test_V2_UpdateDownloadCount_Success() {
	// Create package
	pkg := &metadata.Package{
		Registry:     "npm",
		Name:         "download-test",
		Version:      "1.0.0",
		StorageKey:   "npm/download-test/1.0.0.tgz",
		Size:         1000,
		CachedAt:     time.Now(),
		LastAccessed: time.Now(),
	}
	err := s.store.SavePackage(s.ctx, pkg)
	s.NoError(err)

	// Update download count
	err = s.store.UpdateDownloadCount(s.ctx, "npm", "download-test", "1.0.0")
	s.NoError(err)

	// Verify count incremented
	retrieved, err := s.store.GetPackage(s.ctx, "npm", "download-test", "1.0.0")
	s.NoError(err)
	s.Equal(int64(1), retrieved.DownloadCount)

	// Update again
	err = s.store.UpdateDownloadCount(s.ctx, "npm", "download-test", "1.0.0")
	s.NoError(err)

	retrieved, err = s.store.GetPackage(s.ctx, "npm", "download-test", "1.0.0")
	s.NoError(err)
	s.Equal(int64(2), retrieved.DownloadCount)

	// Verify download event was recorded. The query error is checked so a
	// failed query is not mistaken for a wrong event count.
	var eventCount int64
	err = s.store.db.Model(&DownloadEventModel{}).Count(&eventCount).Error
	s.NoError(err)
	s.Equal(int64(2), eventCount)
}
// Test_V2_Count expects Count to report zero on an empty store and three
// after three packages have been saved.
func (s *GORMStoreV2TestSuite) Test_V2_Count() {
	before, err := s.store.Count(s.ctx)
	s.NoError(err)
	s.Equal(0, before)

	for n := 0; n < 3; n++ {
		entry := &metadata.Package{
			Registry:     "npm",
			Name:         fmt.Sprintf("count-test-%d", n),
			Version:      "1.0.0",
			StorageKey:   fmt.Sprintf("npm/count-test-%d/1.0.0.tgz", n),
			Size:         1000,
			CachedAt:     time.Now(),
			LastAccessed: time.Now(),
		}
		s.NoError(s.store.SavePackage(s.ctx, entry))
	}

	after, err := s.store.Count(s.ctx)
	s.NoError(err)
	s.Equal(3, after)
}
// Test_V2_GetStats saves packages in two registries, records three downloads,
// and checks the aggregated totals both globally and per registry.
func (s *GORMStoreV2TestSuite) Test_V2_GetStats() {
	// Create packages in different registries
	packages := []*metadata.Package{
		{Registry: "npm", Name: "pkg1", Version: "1.0.0", StorageKey: "npm/pkg1/1.0.0.tgz", Size: 1000, CachedAt: time.Now(), LastAccessed: time.Now()},
		{Registry: "npm", Name: "pkg2", Version: "1.0.0", StorageKey: "npm/pkg2/1.0.0.tgz", Size: 2000, CachedAt: time.Now(), LastAccessed: time.Now()},
		{Registry: "pypi", Name: "pkg3", Version: "1.0.0", StorageKey: "pypi/pkg3/1.0.0.tar.gz", Size: 3000, CachedAt: time.Now(), LastAccessed: time.Now()},
	}
	for _, pkg := range packages {
		err := s.store.SavePackage(s.ctx, pkg)
		s.NoError(err)
	}

	// Update download counts: npm/pkg1 twice and npm/pkg2 once, three
	// downloads in total. Each call's error is asserted so a failed update
	// surfaces here instead of as a confusing total below.
	downloads := []struct{ registry, name string }{
		{"npm", "pkg1"},
		{"npm", "pkg1"},
		{"npm", "pkg2"},
	}
	for _, d := range downloads {
		s.NoError(s.store.UpdateDownloadCount(s.ctx, d.registry, d.name, "1.0.0"))
	}

	// Get stats for all registries
	statsAll, err := s.store.GetStats(s.ctx, "")
	s.NoError(err)
	s.Equal(int64(3), statsAll.TotalPackages)
	s.Equal(int64(6000), statsAll.TotalSize)
	s.Equal(int64(3), statsAll.TotalDownloads)

	// Get stats for npm registry
	statsNpm, err := s.store.GetStats(s.ctx, "npm")
	s.NoError(err)
	s.Equal("npm", statsNpm.Registry)
	s.Equal(int64(2), statsNpm.TotalPackages)
	s.Equal(int64(3000), statsNpm.TotalSize)
	s.Equal(int64(3), statsNpm.TotalDownloads)

	// Get stats for pypi registry
	statsPypi, err := s.store.GetStats(s.ctx, "pypi")
	s.NoError(err)
	s.Equal("pypi", statsPypi.Registry)
	s.Equal(int64(1), statsPypi.TotalPackages)
	s.Equal(int64(3000), statsPypi.TotalSize)
}
// Test_V2_Health expects the health check of a freshly created store to
// return no error.
func (s *GORMStoreV2TestSuite) Test_V2_Health() {
	s.NoError(s.store.Health(s.ctx))
}
// Test_V2_RegistryCache checks that the default registries are present in the
// registry cache, that repeated lookups return the same id, and that an
// unknown registry yields a "not found" error.
func (s *GORMStoreV2TestSuite) Test_V2_RegistryCache() {
	// Default registries should be cached
	s.Contains(s.store.registryCache, "npm")
	s.Contains(s.store.registryCache, "pypi")
	s.Contains(s.store.registryCache, "go")

	// Get registry ID from cache
	npmID, err := s.store.getRegistryID("npm")
	s.NoError(err)
	s.Greater(npmID, int32(0))

	// A second lookup must return the same id. It is expected to be served
	// from the cache, but only the returned value is asserted here.
	npmID2, err := s.store.getRegistryID("npm")
	s.NoError(err)
	s.Equal(npmID, npmID2)

	// Non-existent registry
	_, err = s.store.getRegistryID("nonexistent")
	// Require: err.Error() below would dereference nil otherwise.
	s.Require().Error(err)
	s.Contains(err.Error(), "not found")
}
// Test_V2_SoftDelete deletes a package and verifies that it is excluded from
// Count while its row remains in the table with the deletion marker set.
func (s *GORMStoreV2TestSuite) Test_V2_SoftDelete() {
	// Create package
	pkg := &metadata.Package{
		Registry:     "npm",
		Name:         "soft-delete",
		Version:      "1.0.0",
		StorageKey:   "npm/soft-delete/1.0.0.tgz",
		Size:         1000,
		CachedAt:     time.Now(),
		LastAccessed: time.Now(),
	}
	err := s.store.SavePackage(s.ctx, pkg)
	s.NoError(err)

	// Delete
	err = s.store.DeletePackage(s.ctx, "npm", "soft-delete", "1.0.0")
	s.NoError(err)

	// Count should not include deleted
	count, err := s.store.Count(s.ctx)
	s.NoError(err)
	s.Equal(0, count)

	// But record still exists with deleted_at set
	var pkgModel PackageModel
	err = s.store.db.Unscoped().Where("name = ?", "soft-delete").First(&pkgModel).Error
	s.NoError(err)
	// gorm.DeletedAt is a struct value, so a NotNil check on it can never
	// fail; Valid is what records whether a deletion timestamp is present.
	// NOTE(review): assumes PackageModel.DeletedAt is gorm.DeletedAt, as
	// declared on BaseModel — confirm.
	s.True(pkgModel.DeletedAt.Valid, "deleted_at should be set on a soft-deleted row")
}
// Test_V2_AggregationWorker expects a newly created store to have its
// aggregation worker set.
func (s *GORMStoreV2TestSuite) Test_V2_AggregationWorker() {
	worker := s.store.aggregationWorker
	s.NotNil(worker)
}
// Test_V2_ConcurrentUpdates applies five download-count increments one after
// another and expects the stored counter to equal the number of increments.
// Despite its name the updates are sequential on this SQLite-backed suite.
func (s *GORMStoreV2TestSuite) Test_V2_ConcurrentUpdates() {
	target := &metadata.Package{
		Registry:     "npm",
		Name:         "concurrent-test",
		Version:      "1.0.0",
		StorageKey:   "npm/concurrent-test/1.0.0.tgz",
		Size:         1000,
		CachedAt:     time.Now(),
		LastAccessed: time.Now(),
	}
	s.NoError(s.store.SavePackage(s.ctx, target))

	// SQLite: Sequential updates only (write lock prevents concurrent writes)
	const increments = 5
	for done := 0; done < increments; done++ {
		s.NoError(s.store.UpdateDownloadCount(s.ctx, "npm", "concurrent-test", "1.0.0"))
	}

	got, err := s.store.GetPackage(s.ctx, "npm", "concurrent-test", "1.0.0")
	s.NoError(err)
	s.Equal(int64(increments), got.DownloadCount)
}
// Test_V2_SaveScanResult saves a scan result with two vulnerabilities for an
// existing package and expects the package to be flagged as security-scanned.
func (s *GORMStoreV2TestSuite) Test_V2_SaveScanResult() {
	// The scan result refers to a package, so that has to exist first.
	scanned := &metadata.Package{
		Registry:     "npm",
		Name:         "test-package",
		Version:      "1.0.0",
		StorageKey:   "/cache/npm/test-package-1.0.0.tgz",
		Size:         1024,
		UpstreamURL:  "https://registry.npmjs.org/test-package",
		CachedAt:     time.Now(),
		LastAccessed: time.Now(),
	}
	s.NoError(s.store.SavePackage(s.ctx, scanned))

	findings := []metadata.Vulnerability{
		{
			ID:          "CVE-2024-0001",
			Severity:    "HIGH",
			Title:       "Test vulnerability",
			Description: "Test description",
			References:  []string{"https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2024-0001"},
		},
		{
			ID:          "CVE-2024-0002",
			Severity:    "CRITICAL",
			Title:       "Critical vulnerability",
			Description: "Critical test description",
			References:  []string{"https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2024-0002"},
		},
	}
	report := &metadata.ScanResult{
		Registry:           "npm",
		PackageName:        "test-package",
		PackageVersion:     "1.0.0",
		Scanner:            "trivy",
		Status:             metadata.ScanStatusVulnerable,
		ScannedAt:          time.Now(),
		Vulnerabilities:    findings,
		VulnerabilityCount: 2,
		Details: map[string]interface{}{
			"scan_duration":   42,
			"scanner_version": "1.0.0",
		},
	}
	s.NoError(s.store.SaveScanResult(s.ctx, report))

	// Saving the report must have marked the package as scanned.
	got, err := s.store.GetPackage(s.ctx, "npm", "test-package", "1.0.0")
	s.NoError(err)
	s.True(got.SecurityScanned)
}
// Test_V2_GetScanResult saves a scan result with two vulnerabilities and
// checks that GetScanResult returns the scanner, status, count and the first
// vulnerability's details.
func (s *GORMStoreV2TestSuite) Test_V2_GetScanResult() {
	// Create a package
	pkg := &metadata.Package{
		Registry:     "npm",
		Name:         "scan-test",
		Version:      "2.0.0",
		StorageKey:   "/cache/npm/scan-test-2.0.0.tgz",
		Size:         2048,
		UpstreamURL:  "https://registry.npmjs.org/scan-test",
		CachedAt:     time.Now(),
		LastAccessed: time.Now(),
	}
	err := s.store.SavePackage(s.ctx, pkg)
	s.NoError(err)

	// Save a scan result with vulnerabilities
	scanResult := &metadata.ScanResult{
		Registry:       "npm",
		PackageName:    "scan-test",
		PackageVersion: "2.0.0",
		Scanner:        "grype",
		Status:         metadata.ScanStatusVulnerable,
		ScannedAt:      time.Now(),
		Vulnerabilities: []metadata.Vulnerability{
			{
				ID:          "CVE-2024-1234",
				Severity:    "HIGH",
				Title:       "Test High Severity",
				Description: "High severity test",
				References:  []string{"https://example.com/cve-2024-1234"},
				FixedIn:     "2.1.0",
			},
			{
				ID:          "CVE-2024-5678",
				Severity:    "MODERATE",
				Title:       "Test Moderate Severity",
				Description: "Moderate severity test",
				References:  []string{"https://example.com/cve-2024-5678"},
			},
		},
		VulnerabilityCount: 2,
	}
	err = s.store.SaveScanResult(s.ctx, scanResult)
	s.NoError(err)

	// Retrieve the scan result
	retrieved, err := s.store.GetScanResult(s.ctx, "npm", "scan-test", "2.0.0")
	s.NoError(err)
	// Required: retrieved is dereferenced immediately below.
	s.Require().NotNil(retrieved)
	s.Equal("grype", retrieved.Scanner)
	s.Equal(metadata.ScanStatusVulnerable, retrieved.Status)
	s.Equal(2, retrieved.VulnerabilityCount)
	// Required: Vulnerabilities[0] is indexed below.
	s.Require().Len(retrieved.Vulnerabilities, 2)

	// Verify vulnerability details are retrieved correctly.
	// NOTE(review): this relies on the store returning vulnerabilities in
	// the order they were saved — confirm that ordering is guaranteed.
	s.Equal("CVE-2024-1234", retrieved.Vulnerabilities[0].ID)
	s.Equal("HIGH", retrieved.Vulnerabilities[0].Severity)
	s.Equal("Test High Severity", retrieved.Vulnerabilities[0].Title)
	s.Equal("2.1.0", retrieved.Vulnerabilities[0].FixedIn)
	s.Len(retrieved.Vulnerabilities[0].References, 1)
}
// Test_V2_GetScanResult_NotFound expects an error when asking for the scan
// result of a package that was never saved.
func (s *GORMStoreV2TestSuite) Test_V2_GetScanResult_NotFound() {
	_, lookupErr := s.store.GetScanResult(s.ctx, "npm", "nonexistent", "1.0.0")
	s.Error(lookupErr)
}
// Test_V2_SaveCVEBypass saves a new bypass and expects the store to fill in
// its ID and creation time.
func (s *GORMStoreV2TestSuite) Test_V2_SaveCVEBypass() {
	thirtyDays := 30 * 24 * time.Hour

	entry := &metadata.CVEBypass{
		Type:           metadata.BypassTypeCVE,
		Target:         "CVE-2024-0001",
		Reason:         "False positive - not applicable to our use case",
		CreatedBy:      "admin@example.com",
		ExpiresAt:      time.Now().Add(thirtyDays),
		NotifyOnExpiry: true,
		Active:         true,
	}
	s.NoError(s.store.SaveCVEBypass(s.ctx, entry))

	// Both fields were left unset above, so the save must have assigned them.
	s.NotEmpty(entry.ID)
	s.NotZero(entry.CreatedAt)
}
// Test_V2_SaveCVEBypass_Update saves a bypass, modifies it, and saves it again,
// expecting both saves to succeed. The stored values are not read back.
func (s *GORMStoreV2TestSuite) Test_V2_SaveCVEBypass_Update() {
	entry := &metadata.CVEBypass{
		Type:           metadata.BypassTypeCVE,
		Target:         "CVE-2024-0002",
		Reason:         "Initial reason",
		CreatedBy:      "admin@example.com",
		ExpiresAt:      time.Now().Add(30 * 24 * time.Hour),
		NotifyOnExpiry: false,
		Active:         true,
	}

	// First save creates the bypass and assigns its ID.
	s.NoError(s.store.SaveCVEBypass(s.ctx, entry))
	s.NotEmpty(entry.ID)

	// Second save, now carrying an ID, goes through with changed fields.
	entry.Reason = "Updated reason"
	entry.NotifyOnExpiry = true
	s.NoError(s.store.SaveCVEBypass(s.ctx, entry))
}
// Test_V2_GetActiveCVEBypasses saves one active, one expired and one inactive
// bypass and expects GetActiveCVEBypasses to return the active one while every
// returned bypass is active and unexpired.
func (s *GORMStoreV2TestSuite) Test_V2_GetActiveCVEBypasses() {
	// Create active bypass with unique target
	uniqueTarget := fmt.Sprintf("CVE-2024-TEST-%d", time.Now().UnixNano())
	activeBypass := &metadata.CVEBypass{
		Type:      metadata.BypassTypeCVE,
		Target:    uniqueTarget,
		Reason:    "Active bypass",
		CreatedBy: "admin@example.com",
		ExpiresAt: time.Now().Add(30 * 24 * time.Hour),
		Active:    true,
	}
	err := s.store.SaveCVEBypass(s.ctx, activeBypass)
	s.NoError(err)

	// Create expired bypass
	expiredBypass := &metadata.CVEBypass{
		Type:      metadata.BypassTypeCVE,
		Target:    "CVE-2024-0004",
		Reason:    "Expired bypass",
		CreatedBy: "admin@example.com",
		ExpiresAt: time.Now().Add(-24 * time.Hour), // Expired yesterday
		Active:    true,
	}
	err = s.store.SaveCVEBypass(s.ctx, expiredBypass)
	s.NoError(err)

	// Create inactive bypass
	inactiveBypass := &metadata.CVEBypass{
		Type:      metadata.BypassTypeCVE,
		Target:    "CVE-2024-0005",
		Reason:    "Inactive bypass",
		CreatedBy: "admin@example.com",
		ExpiresAt: time.Now().Add(30 * 24 * time.Hour),
		Active:    false,
	}
	err = s.store.SaveCVEBypass(s.ctx, inactiveBypass)
	s.NoError(err)

	// Retrieve active bypasses
	bypasses, err := s.store.GetActiveCVEBypasses(s.ctx)
	s.NoError(err)

	// Validate every returned bypass. The loop deliberately does not stop at
	// our own entry: breaking early would leave it, and everything after it,
	// unchecked, so an expired or inactive bypass could slip through.
	found := false
	for _, b := range bypasses {
		s.True(b.Active, "bypass %s should be active", b.Target)
		s.True(b.ExpiresAt.After(time.Now()), "bypass %s should not be expired", b.Target)
		if b.Target == uniqueTarget {
			found = true
		}
	}
	s.True(found, "Should find our unique active bypass")
}
// Test_V2_ListCVEBypasses saves three bypasses (two CVE-type, one package-type,
// one of them already expired) and exercises the type filter, the expiry
// filter and the limit option of ListCVEBypasses.
func (s *GORMStoreV2TestSuite) Test_V2_ListCVEBypasses() {
	// A nanosecond timestamp keeps the targets unique to this run.
	stamp := time.Now().UnixNano()

	fixtures := []*metadata.CVEBypass{
		{
			Type:      metadata.BypassTypeCVE,
			Target:    fmt.Sprintf("CVE-2024-LIST-%d-1", stamp),
			Reason:    "Test 1",
			CreatedBy: "admin@example.com",
			ExpiresAt: time.Now().Add(30 * 24 * time.Hour),
			Active:    true,
		},
		{
			Type:      metadata.BypassTypePackage,
			Target:    fmt.Sprintf("npm/vulnerable-package@%d", stamp),
			Reason:    "Test 2",
			CreatedBy: "admin@example.com",
			ExpiresAt: time.Now().Add(15 * 24 * time.Hour),
			Active:    true,
		},
		{
			Type:      metadata.BypassTypeCVE,
			Target:    fmt.Sprintf("CVE-2024-LIST-%d-2", stamp),
			Reason:    "Test 3",
			CreatedBy: "admin@example.com",
			ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired
			Active:    true,
		},
	}
	for _, fixture := range fixtures {
		s.NoError(s.store.SaveCVEBypass(s.ctx, fixture))
	}

	// Type filter: every result must be a CVE-type bypass.
	cveOnly, err := s.store.ListCVEBypasses(s.ctx, &metadata.BypassListOptions{
		Type: metadata.BypassTypeCVE,
	})
	s.NoError(err)
	for _, item := range cveOnly {
		s.Equal(metadata.BypassTypeCVE, item.Type)
	}

	// Expiry filter: nothing returned may already be past its expiry.
	nonExpired, err := s.store.ListCVEBypasses(s.ctx, &metadata.BypassListOptions{
		IncludeExpired: false,
	})
	s.NoError(err)
	for _, item := range nonExpired {
		s.True(item.ExpiresAt.After(time.Now()))
	}

	// Pagination: a limit of one yields at most one result.
	firstPage, err := s.store.ListCVEBypasses(s.ctx, &metadata.BypassListOptions{
		Limit:  1,
		Offset: 0,
	})
	s.NoError(err)
	s.LessOrEqual(len(firstPage), 1) // Should be at most 1
}
// Test_V2_DeleteCVEBypass saves a bypass, deletes it by ID, and expects it to
// be absent from the active bypasses afterwards.
func (s *GORMStoreV2TestSuite) Test_V2_DeleteCVEBypass() {
	doomed := &metadata.CVEBypass{
		Type:      metadata.BypassTypeCVE,
		Target:    "CVE-2024-0008",
		Reason:    "To be deleted",
		CreatedBy: "admin@example.com",
		ExpiresAt: time.Now().Add(30 * 24 * time.Hour),
		Active:    true,
	}
	s.NoError(s.store.SaveCVEBypass(s.ctx, doomed))
	s.NotEmpty(doomed.ID)

	s.NoError(s.store.DeleteCVEBypass(s.ctx, doomed.ID))

	// No remaining active bypass may carry the deleted ID.
	remaining, err := s.store.GetActiveCVEBypasses(s.ctx)
	s.NoError(err)
	for _, other := range remaining {
		s.NotEqual(doomed.ID, other.ID)
	}
}
// Test_V2_DeleteCVEBypass_NotFound expects an error when deleting a bypass ID
// that was never created.
func (s *GORMStoreV2TestSuite) Test_V2_DeleteCVEBypass_NotFound() {
	s.Error(s.store.DeleteCVEBypass(s.ctx, "99999999"))
}
// Test_V2_DeleteCVEBypass_InvalidID expects an error when the bypass ID is not
// numeric.
func (s *GORMStoreV2TestSuite) Test_V2_DeleteCVEBypass_InvalidID() {
	s.Error(s.store.DeleteCVEBypass(s.ctx, "invalid-id"))
}
// Test_V2_CleanupExpiredBypasses saves three expired bypasses and one active
// bypass, runs the cleanup, and expects at least three removals with the
// active bypass left in place.
func (s *GORMStoreV2TestSuite) Test_V2_CleanupExpiredBypasses() {
	const keptTarget = "CVE-2024-0999"

	// Three bypasses that expired a day ago.
	for n := 0; n < 3; n++ {
		stale := &metadata.CVEBypass{
			Type:      metadata.BypassTypeCVE,
			Target:    fmt.Sprintf("CVE-2024-00%d", 10+n),
			Reason:    "Expired bypass",
			CreatedBy: "admin@example.com",
			ExpiresAt: time.Now().Add(-24 * time.Hour), // Expired
			Active:    true,
		}
		s.NoError(s.store.SaveCVEBypass(s.ctx, stale))
	}

	// One bypass that is still valid and must survive the cleanup.
	kept := &metadata.CVEBypass{
		Type:      metadata.BypassTypeCVE,
		Target:    keptTarget,
		Reason:    "Active bypass",
		CreatedBy: "admin@example.com",
		ExpiresAt: time.Now().Add(30 * 24 * time.Hour),
		Active:    true,
	}
	s.NoError(s.store.SaveCVEBypass(s.ctx, kept))

	removed, err := s.store.CleanupExpiredBypasses(s.ctx)
	s.NoError(err)
	s.GreaterOrEqual(removed, 3) // At least the 3 we just created

	// Look for the surviving bypass among the active ones.
	active, err := s.store.GetActiveCVEBypasses(s.ctx)
	s.NoError(err)

	survived := false
	for _, candidate := range active {
		if candidate.Target == keptTarget {
			survived = true
			break
		}
	}
	s.True(survived, "Active bypass should still exist")
}
+228
View File
@@ -0,0 +1,228 @@
package gormstore
import (
"github.com/go-gormigrate/gormigrate/v2"
"gorm.io/gorm"
)
// GetMigrations returns the ordered list of schema migrations handed to
// gormigrate. At present it holds a single migration that creates the V2 schema.
func GetMigrations() []*gormigrate.Migration {
	// 202601030001: create the V2 schema; rolling it back drops the schema
	// again (careful!).
	createV2Schema := &gormigrate.Migration{
		ID:       "202601030001",
		Migrate:  migrateToV2,
		Rollback: rollbackFromV2,
	}

	return []*gormigrate.Migration{
		createV2Schema,
		// Future migrations go here
		// {
		//     ID: "202601040001",
		//     Migrate: func(tx *gorm.DB) error {
		//         // Add new column, index, etc.
		//         return tx.Exec("ALTER TABLE packages ADD COLUMN new_field VARCHAR(255)").Error
		//     },
		//     Rollback: func(tx *gorm.DB) error {
		//         return tx.Exec("ALTER TABLE packages DROP COLUMN new_field").Error
		//     },
		// },
	}
}
// migrateToV2 builds the complete V2 schema in three steps: auto-migrate every
// model, seed the default registries, then apply dialect-specific extras for
// PostgreSQL or MySQL. It returns the first error encountered.
func migrateToV2(tx *gorm.DB) error {
	// Step 1: let GORM create or update the tables for all models.
	if err := tx.AutoMigrate(GetAllModels()...); err != nil {
		return err
	}

	// Step 2: make sure the default registries exist (created only when no
	// row with the same name is found).
	defaults := []RegistryModel{
		{Name: "npm", DisplayName: "NPM Registry", UpstreamURL: "https://registry.npmjs.org", Enabled: true, ScanByDefault: true},
		{Name: "pypi", DisplayName: "PyPI", UpstreamURL: "https://pypi.org", Enabled: true, ScanByDefault: true},
		{Name: "go", DisplayName: "Go Modules", UpstreamURL: "https://proxy.golang.org", Enabled: true, ScanByDefault: true},
	}
	for i := range defaults {
		registry := &defaults[i]
		if err := tx.Where("name = ?", registry.Name).FirstOrCreate(registry).Error; err != nil {
			return err
		}
	}

	// Step 3: dialect-specific optimizations; other dialects get none.
	switch tx.Dialector.Name() {
	case "postgres":
		return createPostgreSQLOptimizations(tx)
	case "mysql":
		return createMySQLOptimizations(tx)
	default:
		return nil
	}
}
// createPostgreSQLOptimizations adds PostgreSQL-specific features on top of the
// auto-migrated schema: GIN indexes, partial indexes, the
// v_vulnerable_packages view, and the create_next_month_partitions() function.
//
// The statements are applied best-effort: a failing statement is logged as a
// warning through the GORM logger and the remaining statements are still
// attempted. The function always returns nil.
//
// NOTE(review): if the caller runs this inside a PostgreSQL transaction, the
// first failing statement will presumably abort that transaction and make the
// later statements fail too — confirm how gormigrate is configured.
func createPostgreSQLOptimizations(tx *gorm.DB) error {
	optimizations := []string{
		// Create GIN indexes for JSONB columns
		`CREATE INDEX IF NOT EXISTS idx_package_metadata_keywords_gin
ON package_metadata USING GIN(keywords)`,
		`CREATE INDEX IF NOT EXISTS idx_package_metadata_raw_gin
ON package_metadata USING GIN(raw_metadata)`,
		// "references" is a reserved word in PostgreSQL and must be quoted
		// when used as a column name; unquoted, the statement is a syntax
		// error and the index is never created.
		`CREATE INDEX IF NOT EXISTS idx_vulnerabilities_references_gin
ON vulnerabilities USING GIN("references")`,
		// Create partial indexes (only non-deleted records)
		`CREATE INDEX IF NOT EXISTS idx_packages_active
ON packages(registry_id, name, version) WHERE deleted_at IS NULL`,
		`CREATE INDEX IF NOT EXISTS idx_packages_vulnerable
ON packages(vulnerability_count, highest_severity)
WHERE vulnerability_count > 0 AND deleted_at IS NULL`,
		// Create view for vulnerable packages
		`CREATE OR REPLACE VIEW v_vulnerable_packages AS
SELECT
r.name AS registry,
p.name,
p.version,
p.vulnerability_count,
p.highest_severity,
p.last_scanned_at
FROM packages p
JOIN registries r ON p.registry_id = r.id
WHERE p.vulnerability_count > 0 AND p.deleted_at IS NULL
ORDER BY
CASE p.highest_severity
WHEN 'critical' THEN 1
WHEN 'high' THEN 2
WHEN 'medium' THEN 3
WHEN 'low' THEN 4
ELSE 5
END,
p.vulnerability_count DESC`,
		// Create function for automatic partition creation.
		// NOTE(review): despite its name, next_month is computed from
		// NOW() + 2 months — confirm that offset is intended.
		`CREATE OR REPLACE FUNCTION create_next_month_partitions()
RETURNS void AS $$
DECLARE
next_month DATE := date_trunc('month', NOW() + INTERVAL '2 months');
partition_name TEXT;
start_date TEXT;
end_date TEXT;
BEGIN
-- Download events partition
partition_name := 'download_events_' || to_char(next_month, 'YYYY_MM');
start_date := to_char(next_month, 'YYYY-MM-DD');
end_date := to_char(next_month + INTERVAL '1 month', 'YYYY-MM-DD');
EXECUTE format('CREATE TABLE IF NOT EXISTS %I PARTITION OF download_events FOR VALUES FROM (%L) TO (%L)',
partition_name, start_date, end_date);
-- Audit log partition
partition_name := 'audit_log_' || to_char(next_month, 'YYYY_MM');
EXECUTE format('CREATE TABLE IF NOT EXISTS %I PARTITION OF audit_log FOR VALUES FROM (%L) TO (%L)',
partition_name, start_date, end_date);
RAISE NOTICE 'Created partitions for %', to_char(next_month, 'YYYY-MM');
END;
$$ LANGUAGE plpgsql`,
	}

	for _, stmt := range optimizations {
		if err := tx.Exec(stmt).Error; err != nil {
			// Log warning but don't fail migration: these are
			// optimizations, not requirements, but a failure should at
			// least be visible to whoever runs the migration.
			tx.Logger.Warn(tx.Statement.Context, "skipping PostgreSQL optimization: %v", err)
		}
	}

	return nil
}
// createMySQLOptimizations adds MySQL-specific features.
//
// Each statement is run on a best-effort basis: a statement that fails is
// skipped and the function always returns nil, so an optimization can never
// fail the surrounding migration.
func createMySQLOptimizations(tx *gorm.DB) error {
	// View listing non-deleted packages that have at least one known
	// vulnerability, ordered by severity and then by vulnerability count.
	const vulnerablePackagesView = `CREATE OR REPLACE VIEW v_vulnerable_packages AS
SELECT
r.name AS registry,
p.name,
p.version,
p.vulnerability_count,
p.highest_severity,
p.last_scanned_at
FROM packages p
JOIN registries r ON p.registry_id = r.id
WHERE p.vulnerability_count > 0 AND p.deleted_at IS NULL
ORDER BY
CASE p.highest_severity
WHEN 'critical' THEN 1
WHEN 'high' THEN 2
WHEN 'medium' THEN 3
WHEN 'low' THEN 4
ELSE 5
END,
p.vulnerability_count DESC`

	statements := []string{vulnerablePackagesView}
	for i := range statements {
		// The error is deliberately discarded: optimizations are optional.
		_ = tx.Exec(statements[i]).Error
	}
	return nil
}
// rollbackFromV2 drops all V2 tables (USE WITH CAUTION!).
//
// Dialect-specific helper objects (the vulnerable-packages view and, on
// PostgreSQL, the partition function) are dropped first, then every V2 table.
// Failures are ignored throughout and the function always returns nil.
func rollbackFromV2(tx *gorm.DB) error {
	// Remove the objects that only exist on a particular database engine.
	switch tx.Dialector.Name() {
	case "postgres":
		tx.Exec("DROP VIEW IF EXISTS v_vulnerable_packages")
		tx.Exec("DROP FUNCTION IF EXISTS create_next_month_partitions()")
	case "mysql":
		tx.Exec("DROP VIEW IF EXISTS v_vulnerable_packages")
	}

	// Listed in reverse dependency order so that tables holding foreign keys
	// are removed before the tables they reference.
	dropOrder := [...]string{
		"audit_log",
		"download_stats_daily",
		"download_stats_hourly",
		"download_events",
		"cve_bypasses",
		"scan_results",
		"package_vulnerabilities",
		"vulnerabilities",
		"package_metadata",
		"packages",
		"registries",
	}
	for _, name := range dropOrder {
		// A table that is already gone is not an error for a rollback.
		_ = tx.Migrator().DropTable(name)
	}
	return nil
}
+328
View File
@@ -0,0 +1,328 @@
package gormstore
import (
	"database/sql/driver"
	"encoding/json"
	"fmt"
	"time"

	"gorm.io/gorm"
)
// BaseModel provides common fields for all models with audit trail.
// It is embedded by the mutable models in this file; the append-only models
// (DownloadEventModel, AuditLogModel) deliberately do not embed it.
type BaseModel struct {
// CreatedAt is the row creation time; the column is NOT NULL.
CreatedAt time.Time `gorm:"not null"`
// UpdatedAt is the time of the last modification; the column is NOT NULL.
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"` // Soft delete support (auto-generated index name per table)
}
// RegistryModel represents package registries (normalized)
// This eliminates repetition of "npm", "pypi", "go" across millions of rows:
// other tables refer to a registry through its numeric ID instead of its name.
type RegistryModel struct {
// ID is the auto-incremented primary key referenced by RegistryID columns.
ID int32 `gorm:"primaryKey;autoIncrement"`
Name string `gorm:"uniqueIndex:idx_registry_name;not null;size:50"` // npm, pypi, go
DisplayName string `gorm:"not null;size:100"` // NPM Registry, PyPI, Go Modules
UpstreamURL string `gorm:"not null;size:512"` // https://registry.npmjs.org
// Enabled defaults to true and is indexed for filtering.
Enabled bool `gorm:"not null;default:true;index:idx_registry_enabled"`
// ScanByDefault defaults to true.
// NOTE(review): presumably controls whether packages of this registry are
// security-scanned unless configured otherwise - confirm against the scanner.
ScanByDefault bool `gorm:"not null;default:true"`
BaseModel
}
// TableName returns the name of the table backing RegistryModel.
func (RegistryModel) TableName() string {
return "registries"
}
// PackageModel represents the core package data (optimized)
// One row describes one cached package version of one registry, together with
// its storage location, cache bookkeeping, denormalized security counters and
// authentication requirements.
type PackageModel struct {
ID int64 `gorm:"primaryKey;autoIncrement"`
// RegistryID, Name and Version together form the composite index
// idx_package_registry_name_version (in that priority order).
RegistryID int32 `gorm:"not null;index:idx_package_registry_name_version,priority:1"` // Foreign key to registries
Name string `gorm:"not null;size:255;index:idx_package_name;index:idx_package_registry_name_version,priority:2"`
Version string `gorm:"not null;size:100;index:idx_package_registry_name_version,priority:3"`
// Storage information
// StorageKey is unique across all packages.
StorageKey string `gorm:"not null;uniqueIndex:idx_package_storage_key;size:512"`
Size int64 `gorm:"not null;index:idx_package_size"` // For storage quota queries
ChecksumMD5 string `gorm:"size:32;index:idx_package_md5"`
ChecksumSHA256 string `gorm:"size:64;index:idx_package_sha256"`
UpstreamURL string `gorm:"size:1024"`
// Cache management
CachedAt time.Time `gorm:"not null;index:idx_package_cached_at"`
LastAccessed time.Time `gorm:"not null;index:idx_package_last_accessed"` // For LRU eviction
ExpiresAt *time.Time `gorm:"index:idx_package_expires_at"` // For cache invalidation
AccessCount int64 `gorm:"not null;default:0;index:idx_package_access_count"` // Total access count (denormalized for performance)
// Security
SecurityScanned bool `gorm:"not null;default:false;index:idx_package_security_scanned"`
LastScannedAt *time.Time `gorm:"index:idx_package_last_scanned"`
VulnerabilityCount int `gorm:"not null;default:0;index:idx_package_vuln_count"` // Denormalized for fast filtering
HighestSeverity string `gorm:"size:20;index:idx_package_severity"` // critical, high, medium, low, none
CriticalCount int `gorm:"not null;default:0"` // Count of critical vulnerabilities
HighCount int `gorm:"not null;default:0"` // Count of high vulnerabilities
// NOTE(review): this field is called "Moderate" while HighestSeverity and
// ScanResultModel.MediumCount use "medium" - confirm the two are the same level.
ModerateCount int `gorm:"not null;default:0"` // Count of moderate vulnerabilities
LowCount int `gorm:"not null;default:0"` // Count of low vulnerabilities
// Authentication
RequiresAuth bool `gorm:"not null;default:false;index:idx_package_requires_auth"`
AuthProvider string `gorm:"size:50;index:idx_package_auth_provider"` // github, gitlab, custom
BaseModel
// Relationships
// The registry reference is declared ON DELETE RESTRICT; the dependent
// metadata, scan results and vulnerability links are declared ON DELETE CASCADE.
Registry RegistryModel `gorm:"foreignKey:RegistryID;constraint:OnUpdate:CASCADE,OnDelete:RESTRICT"`
Metadata *PackageMetadataModel `gorm:"foreignKey:PackageID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"`
ScanResults []ScanResultModel `gorm:"foreignKey:PackageID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"`
Vulnerabilities []PackageVulnerabilityModel `gorm:"foreignKey:PackageID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"`
}
// TableName returns the name of the table backing PackageModel.
func (PackageModel) TableName() string {
return "packages"
}
// BeforeCreate is the GORM hook that runs before a PackageModel is inserted.
//
// It currently has nothing to do: a new package starts with AccessCount at
// its zero value, which matches the column's default:0 tag. The hook is kept
// so the model's method set is unchanged and initialisation logic has a home
// if it is ever needed. It always returns nil.
func (p *PackageModel) BeforeCreate(tx *gorm.DB) error {
	return nil
}
// PackageMetadataModel stores structured package metadata (1:1 with packages)
// Separated from main table to reduce row size and improve query performance
type PackageMetadataModel struct {
// PackageID is both the primary key and the link to the owning package.
PackageID int64 `gorm:"primaryKey;not null"` // 1:1 relationship
Author string `gorm:"size:255;index:idx_metadata_author"`
License string `gorm:"size:100;index:idx_metadata_license"`
Homepage string `gorm:"size:512"`
Repository string `gorm:"size:512"`
Description string `gorm:"type:text"`
Keywords PostgresArray `gorm:"type:text"` // Stored in a text column as a JSON-encoded string array (see PostgresArray)
// NOTE(review): the column type is declared as "jsonb", which is a
// PostgreSQL type name - confirm how the MySQL and SQLite schemas are created.
RawMetadata JSONBField `gorm:"type:jsonb"` // Full metadata as JSONB (PostgreSQL) or JSON
BaseModel
}
// TableName returns the name of the table backing PackageMetadataModel.
func (PackageMetadataModel) TableName() string {
return "package_metadata"
}
// ScanResultModel represents security scan results (optimized)
// Each row records one run of one scanner against one package, with
// per-severity vulnerability counters and scanner-specific details.
type ScanResultModel struct {
ID int64 `gorm:"primaryKey;autoIncrement"`
// PackageID and Scanner together form the composite index
// idx_scan_package_scanner (in that priority order).
PackageID int64 `gorm:"not null;index:idx_scan_package_scanner,priority:1"` // Foreign key
Scanner string `gorm:"not null;size:50;index:idx_scan_scanner;index:idx_scan_package_scanner,priority:2"`
ScannedAt time.Time `gorm:"not null;index:idx_scan_scanned_at"`
Status string `gorm:"not null;size:20;index:idx_scan_status"` // success, failed, pending
VulnCount int `gorm:"not null;default:0;index:idx_scan_vuln_count"`
CriticalCount int `gorm:"not null;default:0"`
HighCount int `gorm:"not null;default:0"`
MediumCount int `gorm:"not null;default:0"`
LowCount int `gorm:"not null;default:0"`
ScanDuration int `gorm:"not null;default:0"` // milliseconds
Details JSONBField `gorm:"type:jsonb"` // Scanner-specific details
BaseModel
// Package is the owning package; the constraint is declared ON DELETE CASCADE.
Package PackageModel `gorm:"foreignKey:PackageID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"`
}
// TableName returns the name of the table backing ScanResultModel.
func (ScanResultModel) TableName() string {
return "scan_results"
}
// VulnerabilityModel represents unique vulnerabilities (normalized)
// A vulnerability is stored once, keyed by its unique CVEID, and linked to
// packages through PackageVulnerabilityModel.
type VulnerabilityModel struct {
ID int64 `gorm:"primaryKey;autoIncrement"`
CVEID string `gorm:"uniqueIndex:idx_vuln_cve_id;not null;size:50"` // CVE-2021-12345
Title string `gorm:"not null;size:512"`
Description string `gorm:"type:text"`
Severity string `gorm:"not null;size:20;index:idx_vuln_severity"` // critical, high, medium, low
CVSS float32 `gorm:"index:idx_vuln_cvss"` // CVSS score for sorting
PublishedAt time.Time `gorm:"not null;index:idx_vuln_published"`
FixedVersion string `gorm:"size:100"` // First version where it's fixed
// NOTE(review): the resulting column name "references" is an SQL keyword;
// hand-written SQL touching it presumably needs to quote it - confirm.
References PostgresArray `gorm:"type:text"` // URLs to advisories, stored as a JSON-encoded string array
BaseModel
}
// TableName returns the name of the table backing VulnerabilityModel.
func (VulnerabilityModel) TableName() string {
return "vulnerabilities"
}
// PackageVulnerabilityModel is a many-to-many relationship between packages and vulnerabilities
// Each row states that a scanner detected a given vulnerability in a given
// package, and whether that finding is currently bypassed.
type PackageVulnerabilityModel struct {
ID int64 `gorm:"primaryKey;autoIncrement"`
// PackageID and VulnerabilityID each have their own index and together
// form the composite index idx_pkg_vuln_composite.
PackageID int64 `gorm:"not null;index:idx_pkg_vuln_package,priority:1;index:idx_pkg_vuln_composite,priority:1"`
VulnerabilityID int64 `gorm:"not null;index:idx_pkg_vuln_vuln,priority:1;index:idx_pkg_vuln_composite,priority:2"`
Scanner string `gorm:"not null;size:50;index:idx_pkg_vuln_scanner"`
DetectedAt time.Time `gorm:"not null;index:idx_pkg_vuln_detected"`
Bypassed bool `gorm:"not null;default:false;index:idx_pkg_vuln_bypassed"`
BypassID *int64 `gorm:"index:idx_pkg_vuln_bypass_id"` // Reference to bypass if applicable
BaseModel
// Both relations are declared ON DELETE CASCADE.
Package PackageModel `gorm:"foreignKey:PackageID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"`
Vulnerability VulnerabilityModel `gorm:"foreignKey:VulnerabilityID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"`
}
// TableName returns the name of the table backing PackageVulnerabilityModel.
func (PackageVulnerabilityModel) TableName() string {
return "package_vulnerabilities"
}
// CVEBypassModel represents CVE bypass rules (improved)
// A rule has a type and target, a mandatory reason, author and expiry, usage
// bookkeeping, and an optional scope restricted to a registry and/or package.
type CVEBypassModel struct {
ID int64 `gorm:"primaryKey;autoIncrement"`
Type string `gorm:"not null;size:20;index:idx_bypass_type"` // cve, package, registry
Target string `gorm:"not null;size:512;index:idx_bypass_target"` // CVE-ID, package name, etc.
Reason string `gorm:"not null;type:text"`
CreatedBy string `gorm:"not null;size:255;index:idx_bypass_created_by"`
// ExpiresAt is mandatory (NOT NULL): every bypass carries an expiry time.
ExpiresAt time.Time `gorm:"not null;index:idx_bypass_expires_at"`
NotifyOnExpiry bool `gorm:"not null;default:false"`
Active bool `gorm:"not null;default:true;index:idx_bypass_active"`
UsageCount int64 `gorm:"not null;default:0"` // How many times this bypass has been used
LastUsedAt *time.Time `gorm:"index:idx_bypass_last_used"`
// Scope limiting (optional)
RegistryID *int32 `gorm:"index:idx_bypass_registry"` // NULL = all registries
PackageID *int64 `gorm:"index:idx_bypass_package"` // NULL = all packages
BaseModel
}
// TableName returns the name of the table backing CVEBypassModel.
func (CVEBypassModel) TableName() string {
return "cve_bypasses"
}
// DownloadEventModel represents raw download events (partitioned by month)
// This table should use PostgreSQL partitioning or time-series DB features
// One row is written per download; rows are never updated individually.
type DownloadEventModel struct {
ID int64 `gorm:"primaryKey;autoIncrement"`
// PackageID and DownloadedAt together form the composite index
// idx_download_package (in that priority order).
PackageID int64 `gorm:"not null;index:idx_download_package,priority:1"`
RegistryID int32 `gorm:"not null;index:idx_download_registry"`
DownloadedAt time.Time `gorm:"not null;index:idx_download_time;index:idx_download_package,priority:2"` // Partition key
UserAgent string `gorm:"size:512"` // For analytics
IPAddress string `gorm:"size:45;index:idx_download_ip"` // IPv6 support
Authenticated bool `gorm:"not null;default:false"`
Username string `gorm:"size:255;index:idx_download_user"`
// No BaseModel - this is append-only, no updates/deletes on individual rows
// Partitioned tables handle cleanup via DROP PARTITION
}
// TableName returns the name of the table backing DownloadEventModel.
func (DownloadEventModel) TableName() string {
return "download_events"
}
// DownloadStatsHourlyModel represents pre-aggregated hourly statistics (partitioned)
// A row holds the counters of one registry (optionally one package) for one
// hour-sized time bucket.
type DownloadStatsHourlyModel struct {
ID int64 `gorm:"primaryKey;autoIncrement"`
// RegistryID and TimeBucket together form the composite index
// idx_stats_hourly_composite (in that priority order).
RegistryID int32 `gorm:"not null;index:idx_stats_hourly_composite,priority:1"`
PackageID *int64 `gorm:"index:idx_stats_hourly_package"` // NULL = all packages in registry
TimeBucket time.Time `gorm:"not null;index:idx_stats_hourly_composite,priority:2"` // Truncated to hour
DownloadCount int64 `gorm:"not null;default:0"`
UniqueIPs int64 `gorm:"not null;default:0"` // Unique downloaders
AuthDownloads int64 `gorm:"not null;default:0"` // Authenticated downloads
BaseModel
}
// TableName returns the name of the table backing DownloadStatsHourlyModel.
func (DownloadStatsHourlyModel) TableName() string {
return "download_stats_hourly"
}
// DownloadStatsDailyModel represents pre-aggregated daily statistics
// A row holds the counters of one registry (optionally one package) for one
// day-sized time bucket, plus the most frequent user agents of that day.
type DownloadStatsDailyModel struct {
ID int64 `gorm:"primaryKey;autoIncrement"`
// RegistryID and TimeBucket together form the composite index
// idx_stats_daily_composite (in that priority order).
RegistryID int32 `gorm:"not null;index:idx_stats_daily_composite,priority:1"`
PackageID *int64 `gorm:"index:idx_stats_daily_package"` // NULL = all packages in registry
TimeBucket time.Time `gorm:"not null;index:idx_stats_daily_composite,priority:2"` // Truncated to day
DownloadCount int64 `gorm:"not null;default:0"`
UniqueIPs int64 `gorm:"not null;default:0"` // Unique downloaders
AuthDownloads int64 `gorm:"not null;default:0"` // Authenticated downloads
TopUserAgents JSONBField `gorm:"type:jsonb"` // Top 10 user agents
BaseModel
}
// TableName returns the name of the table backing DownloadStatsDailyModel.
func (DownloadStatsDailyModel) TableName() string {
return "download_stats_daily"
}
// AuditLogModel tracks all important changes (optional, for compliance)
// Each row records who performed which action on which entity, when, and
// from where, together with the changed values.
type AuditLogModel struct {
ID int64 `gorm:"primaryKey;autoIncrement"`
EntityType string `gorm:"not null;size:50;index:idx_audit_entity_type"` // package, bypass, registry
// EntityID is the ID of the affected row in the table named by EntityType.
EntityID int64 `gorm:"not null;index:idx_audit_entity_id"`
Action string `gorm:"not null;size:20;index:idx_audit_action"` // create, update, delete
Username string `gorm:"not null;size:255;index:idx_audit_username"`
Timestamp time.Time `gorm:"not null;index:idx_audit_timestamp"`
Changes JSONBField `gorm:"type:jsonb"` // Before/after values
IPAddress string `gorm:"size:45"`
UserAgent string `gorm:"size:512"`
// No BaseModel - append-only audit log
}
// TableName returns the name of the table backing AuditLogModel.
func (AuditLogModel) TableName() string {
return "audit_log"
}
// JSONBField is a custom type for JSONB (PostgreSQL) / JSON (MySQL/SQLite)
// columns. It holds a JSON object and converts to and from its encoded form
// when written to or read from the database.
type JSONBField map[string]interface{}

// Value implements driver.Valuer. A nil map is stored as SQL NULL; any other
// map is stored as its JSON encoding.
func (j JSONBField) Value() (driver.Value, error) {
	if j == nil {
		return nil, nil
	}
	return json.Marshal(j)
}

// Scan implements sql.Scanner. It accepts SQL NULL (the field becomes nil)
// and JSON text delivered either as []byte or as string, since drivers differ
// in which of the two they hand back for JSON/text columns.
//
// Any other source type is reported as an error instead of being silently
// ignored, so a driver mismatch cannot masquerade as an empty field. The
// receiver is only modified when decoding succeeds.
func (j *JSONBField) Scan(value interface{}) error {
	var raw []byte
	switch v := value.(type) {
	case nil:
		*j = nil
		return nil
	case []byte:
		raw = v
	case string:
		raw = []byte(v)
	default:
		return fmt.Errorf("gormstore: cannot scan %T into JSONBField", value)
	}

	// Decode into a fresh map so keys from a previous value never leak into
	// the result.
	var decoded JSONBField
	if err := json.Unmarshal(raw, &decoded); err != nil {
		return fmt.Errorf("gormstore: decoding JSONBField: %w", err)
	}
	*j = decoded
	return nil
}
// PostgresArray is a custom type for string lists. Despite its name it is not
// a native PostgreSQL array: the list is stored as a JSON-encoded array.
type PostgresArray []string

// Value implements driver.Valuer. A nil slice is stored as SQL NULL; any
// other slice (including an empty one) is stored as its JSON encoding.
func (a PostgresArray) Value() (driver.Value, error) {
	if a == nil {
		return nil, nil
	}
	return json.Marshal(a)
}

// Scan implements sql.Scanner. It accepts SQL NULL (the slice becomes nil)
// and JSON text delivered either as []byte or as string, since drivers differ
// in which of the two they hand back for text columns.
//
// Any other source type is reported as an error instead of being silently
// ignored. The receiver is only modified when decoding succeeds.
func (a *PostgresArray) Scan(value interface{}) error {
	var raw []byte
	switch v := value.(type) {
	case nil:
		*a = nil
		return nil
	case []byte:
		raw = v
	case string:
		raw = []byte(v)
	default:
		return fmt.Errorf("gormstore: cannot scan %T into PostgresArray", value)
	}

	// Decode into a fresh slice so elements of a previous value never leak
	// into the result.
	var decoded PostgresArray
	if err := json.Unmarshal(raw, &decoded); err != nil {
		return fmt.Errorf("gormstore: decoding PostgresArray: %w", err)
	}
	*a = decoded
	return nil
}
// GetAllModels returns all models for GORM auto-migration.
//
// The result contains one pointer to a zero value of every model declared in
// this package. Referenced tables (registries, packages, ...) are listed
// before the tables that point at them.
func GetAllModels() []interface{} {
	models := make([]interface{}, 0, 11)
	models = append(models,
		new(RegistryModel),
		new(PackageModel),
		new(PackageMetadataModel),
		new(ScanResultModel),
		new(VulnerabilityModel),
		new(PackageVulnerabilityModel),
		new(CVEBypassModel),
		new(DownloadEventModel),
		new(DownloadStatsHourlyModel),
		new(DownloadStatsDailyModel),
		new(AuditLogModel),
	)
	return models
}
+380
View File
@@ -0,0 +1,380 @@
package gormstore
import (
"fmt"
"time"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
)
// PartitionManager handles automatic partition creation and cleanup for
// PostgreSQL. Its methods are no-ops on every other database dialect.
type PartitionManager struct {
	// db is the connection all partition statements are executed on.
	db *gorm.DB
}

// NewPartitionManager creates a new partition manager bound to db.
func NewPartitionManager(db *gorm.DB) *PartitionManager {
	manager := new(PartitionManager)
	manager.db = db
	return manager
}
// EnsurePartitions ensures required partitions exist for current and future
// months.
//
// On a non-PostgreSQL database it does nothing and returns nil. Otherwise it
// prepares the download_events and audit_log partitions, returning the first
// error, and finally installs the partition helper function; a failure of
// that last step is only logged.
func (pm *PartitionManager) EnsurePartitions() error {
	if dialect := pm.db.Dialector.Name(); dialect != "postgres" {
		log.Debug().Msg("Partitioning only supported on PostgreSQL, skipping")
		return nil
	}
	log.Info().Msg("Ensuring partitions exist")

	// Mandatory steps, executed in order; the first failure aborts.
	required := []func() error{
		pm.ensureDownloadEventPartitions,
		pm.ensureAuditLogPartitions,
	}
	for _, step := range required {
		if stepErr := step(); stepErr != nil {
			return stepErr
		}
	}

	// Optional step: the helper function is a convenience, not a requirement.
	if fnErr := pm.createPartitionFunction(); fnErr != nil {
		log.Warn().Err(fnErr).Msg("Failed to create partition function (may already exist)")
	}
	return nil
}
// ensureDownloadEventPartitions makes sure download_events is a table
// partitioned by range on downloaded_at and that monthly partitions exist for
// the past three months, the current month and the next three months.
//
// If download_events is not partitioned yet, an existing table of that name
// is renamed to download_events_old (its rows are NOT copied) and a new
// partitioned table is created. Failures to create individual partitions are
// logged and do not fail the call.
func (pm *PartitionManager) ensureDownloadEventPartitions() error {
	// to_regclass yields NULL for a missing table, so the check reports
	// "not partitioned" instead of failing when download_events does not
	// exist yet (a plain ::regclass cast raises an error in that case).
	var isPartitioned bool
	err := pm.db.Raw(`
SELECT EXISTS (
SELECT 1 FROM pg_partitioned_table
WHERE partrelid = to_regclass('download_events')
)
`).Scan(&isPartitioned).Error
	if err != nil {
		return fmt.Errorf("checking whether download_events is partitioned: %w", err)
	}

	if !isPartitioned {
		log.Info().Msg("Converting download_events to partitioned table")

		// Move a pre-existing plain table out of the way.
		if err := pm.db.Exec("ALTER TABLE IF EXISTS download_events RENAME TO download_events_old").Error; err != nil {
			log.Warn().Err(err).Msg("Could not rename old table (may not exist)")
		}

		createTableSQL := `
CREATE TABLE IF NOT EXISTS download_events (
id BIGSERIAL,
package_id BIGINT NOT NULL,
registry_id INTEGER NOT NULL,
downloaded_at TIMESTAMP NOT NULL,
user_agent VARCHAR(512),
ip_address VARCHAR(45),
authenticated BOOLEAN NOT NULL DEFAULT FALSE,
username VARCHAR(255)
) PARTITION BY RANGE (downloaded_at)
`
		if err := pm.db.Exec(createTableSQL).Error; err != nil {
			return fmt.Errorf("failed to create partitioned table: %w", err)
		}
		log.Info().Msg("Created partitioned download_events table")
	}

	// Anchor on the first day of the current month before stepping by whole
	// months: AddDate normalizes overflowing days, so stepping from e.g.
	// January 31st would land in March and skip February entirely. UTC is
	// used because partition boundaries are computed in UTC.
	now := time.Now().UTC()
	currentMonth := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC)
	for offset := -3; offset <= 3; offset++ {
		month := currentMonth.AddDate(0, offset, 0)
		if err := pm.createDownloadEventPartition(month); err != nil {
			log.Error().Err(err).Time("month", month).Msg("Failed to create partition")
		}
	}
	return nil
}
// createDownloadEventPartition creates the download_events partition covering
// the calendar month that contains month, unless it already exists.
//
// The partition is named download_events_YYYY_MM and spans from the first day
// of that month (inclusive) to the first day of the following month
// (exclusive). Index creation failures are logged but not returned.
func (pm *PartitionManager) createDownloadEventPartition(month time.Time) error {
	const dateLayout = "2006-01-02"

	year, mon := month.Year(), month.Month()
	name := fmt.Sprintf("download_events_%d_%02d", year, mon)
	from := time.Date(year, mon, 1, 0, 0, 0, 0, time.UTC)
	to := from.AddDate(0, 1, 0)

	// Nothing to do when a table with the partition's name is already there.
	var present bool
	if lookupErr := pm.db.Raw("SELECT EXISTS (SELECT 1 FROM pg_tables WHERE tablename = ?)", name).Scan(&present).Error; lookupErr != nil {
		return lookupErr
	}
	if present {
		log.Debug().Str("partition", name).Msg("Partition already exists")
		return nil
	}

	ddl := fmt.Sprintf(`
CREATE TABLE %s PARTITION OF download_events
FOR VALUES FROM ('%s') TO ('%s')
`, name, from.Format(dateLayout), to.Format(dateLayout))
	if createErr := pm.db.Exec(ddl).Error; createErr != nil {
		return fmt.Errorf("failed to create partition %s: %w", name, createErr)
	}

	// Per-partition indexes; each template takes the partition name twice
	// (once for the index name, once for the table).
	indexTemplates := [...]string{
		"CREATE INDEX IF NOT EXISTS %s_package_idx ON %s(package_id, downloaded_at)",
		"CREATE INDEX IF NOT EXISTS %s_registry_idx ON %s(registry_id)",
		"CREATE INDEX IF NOT EXISTS %s_time_idx ON %s(downloaded_at)",
	}
	for _, tmpl := range indexTemplates {
		stmt := fmt.Sprintf(tmpl, name, name)
		if indexErr := pm.db.Exec(stmt).Error; indexErr != nil {
			log.Warn().Err(indexErr).Str("sql", stmt).Msg("Failed to create index")
		}
	}

	log.Info().Str("partition", name).Msg("Created partition")
	return nil
}
// ensureAuditLogPartitions creates the audit_log table, partitioned by range
// on timestamp, when no table of that name exists, and makes sure monthly
// partitions exist for the previous month, the current month and the next two
// months. Failures to create individual partitions are logged and do not fail
// the call.
//
// NOTE(review): an audit_log table that already exists but is not partitioned
// is left untouched here; creating partitions for it will then fail and only
// be logged - confirm this is the intended behaviour.
func (pm *PartitionManager) ensureAuditLogPartitions() error {
	var exists bool
	err := pm.db.Raw("SELECT EXISTS (SELECT 1 FROM pg_tables WHERE tablename = 'audit_log')").Scan(&exists).Error
	if err != nil {
		return err
	}

	if !exists {
		createTableSQL := `
CREATE TABLE IF NOT EXISTS audit_log (
id BIGSERIAL,
entity_type VARCHAR(50) NOT NULL,
entity_id BIGINT NOT NULL,
action VARCHAR(20) NOT NULL,
username VARCHAR(255) NOT NULL,
timestamp TIMESTAMP NOT NULL,
changes JSONB,
ip_address VARCHAR(45),
user_agent VARCHAR(512)
) PARTITION BY RANGE (timestamp)
`
		if err := pm.db.Exec(createTableSQL).Error; err != nil {
			return fmt.Errorf("failed to create audit_log table: %w", err)
		}
		log.Info().Msg("Created partitioned audit_log table")
	}

	// Anchor on the first day of the current month before stepping by whole
	// months: AddDate normalizes overflowing days, so stepping from e.g.
	// January 31st would land in March and skip February entirely. UTC is
	// used because partition boundaries are computed in UTC.
	now := time.Now().UTC()
	currentMonth := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC)
	for offset := -1; offset <= 2; offset++ {
		month := currentMonth.AddDate(0, offset, 0)
		if err := pm.createAuditLogPartition(month); err != nil {
			log.Error().Err(err).Time("month", month).Msg("Failed to create audit partition")
		}
	}
	return nil
}
// createAuditLogPartition creates the audit_log partition covering the
// calendar month that contains month, unless it already exists.
//
// The partition is named audit_log_YYYY_MM and spans from the first day of
// that month (inclusive) to the first day of the following month (exclusive).
// Index creation failures are logged but not returned.
func (pm *PartitionManager) createAuditLogPartition(month time.Time) error {
	const dateLayout = "2006-01-02"

	year, mon := month.Year(), month.Month()
	name := fmt.Sprintf("audit_log_%d_%02d", year, mon)
	from := time.Date(year, mon, 1, 0, 0, 0, 0, time.UTC)
	to := from.AddDate(0, 1, 0)

	// Nothing to do when a table with the partition's name is already there.
	var present bool
	if lookupErr := pm.db.Raw("SELECT EXISTS (SELECT 1 FROM pg_tables WHERE tablename = ?)", name).Scan(&present).Error; lookupErr != nil {
		return lookupErr
	}
	if present {
		return nil
	}

	ddl := fmt.Sprintf(`
CREATE TABLE %s PARTITION OF audit_log
FOR VALUES FROM ('%s') TO ('%s')
`, name, from.Format(dateLayout), to.Format(dateLayout))
	if createErr := pm.db.Exec(ddl).Error; createErr != nil {
		return fmt.Errorf("failed to create partition %s: %w", name, createErr)
	}

	// Per-partition indexes; each template takes the partition name twice
	// (once for the index name, once for the table).
	indexTemplates := [...]string{
		"CREATE INDEX IF NOT EXISTS %s_entity_idx ON %s(entity_type, entity_id)",
		"CREATE INDEX IF NOT EXISTS %s_user_idx ON %s(username)",
		"CREATE INDEX IF NOT EXISTS %s_time_idx ON %s(timestamp)",
	}
	for _, tmpl := range indexTemplates {
		if indexErr := pm.db.Exec(fmt.Sprintf(tmpl, name, name)).Error; indexErr != nil {
			log.Warn().Err(indexErr).Msg("Failed to create audit index")
		}
	}

	log.Info().Str("partition", name).Msg("Created audit partition")
	return nil
}
// createPartitionFunction installs (or replaces) the PostgreSQL function
// create_next_month_partitions().
//
// When called, that function creates the download_events and audit_log
// partitions, plus some of their indexes, for the month that lies two months
// after the current date (despite the "next_month" variable name). This
// method only defines the function; it does not call or schedule it.
func (pm *PartitionManager) createPartitionFunction() error {
	const partitionFunctionDDL = `
CREATE OR REPLACE FUNCTION create_next_month_partitions()
RETURNS void AS $$
DECLARE
next_month DATE := date_trunc('month', NOW() + INTERVAL '2 months');
partition_name TEXT;
start_date TEXT;
end_date TEXT;
BEGIN
-- Create download_events partition
partition_name := 'download_events_' || to_char(next_month, 'YYYY_MM');
start_date := to_char(next_month, 'YYYY-MM-DD');
end_date := to_char(next_month + INTERVAL '1 month', 'YYYY-MM-DD');
EXECUTE format('CREATE TABLE IF NOT EXISTS %I PARTITION OF download_events FOR VALUES FROM (%L) TO (%L)',
partition_name, start_date, end_date);
EXECUTE format('CREATE INDEX IF NOT EXISTS %I ON %I(package_id, downloaded_at)',
partition_name || '_package_idx', partition_name);
EXECUTE format('CREATE INDEX IF NOT EXISTS %I ON %I(registry_id)',
partition_name || '_registry_idx', partition_name);
-- Create audit_log partition
partition_name := 'audit_log_' || to_char(next_month, 'YYYY_MM');
EXECUTE format('CREATE TABLE IF NOT EXISTS %I PARTITION OF audit_log FOR VALUES FROM (%L) TO (%L)',
partition_name, start_date, end_date);
EXECUTE format('CREATE INDEX IF NOT EXISTS %I ON %I(entity_type, entity_id)',
partition_name || '_entity_idx', partition_name);
RAISE NOTICE 'Created partitions for %', to_char(next_month, 'YYYY-MM');
END;
$$ LANGUAGE plpgsql;
`
	result := pm.db.Exec(partitionFunctionDDL)
	if result.Error == nil {
		log.Info().Msg("Created partition management function")
		return nil
	}
	return result.Error
}
// CleanupOldPartitions drops the monthly download_events and audit_log
// partitions that are older than the retention period.
//
// retentionMonths is the number of whole months to keep before the current
// month: with a value of 3 in June, the partitions for March and later are
// kept and older ones are dropped. A negative value is rejected, because it
// would move the cutoff into the future and drop live partitions.
//
// On a non-PostgreSQL database the call is a no-op. An error is returned when
// the partition list cannot be read; a failure to drop a single partition is
// logged and the remaining partitions are still processed.
func (pm *PartitionManager) CleanupOldPartitions(retentionMonths int) error {
	if pm.db.Dialector.Name() != "postgres" {
		return nil
	}
	if retentionMonths < 0 {
		return fmt.Errorf("retention months must not be negative, got %d", retentionMonths)
	}

	// Anchor on the first day of the current month before stepping back:
	// AddDate normalizes overflowing days, so stepping back one month from
	// March 31st would otherwise yield March 3rd and drop February too early.
	now := time.Now().UTC()
	cutoffDate := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC).AddDate(0, -retentionMonths, 0)
	cutoffPartition := fmt.Sprintf("%d_%02d", cutoffDate.Year(), cutoffDate.Month())

	log.Info().
		Str("cutoff", cutoffPartition).
		Int("retention_months", retentionMonths).
		Msg("Cleaning up old partitions")

	targets := []struct {
		parent  string // partitioned parent table
		dropMsg string // log message emitted before dropping a partition
		failMsg string // log message emitted when dropping fails
	}{
		{"download_events", "Dropping old partition", "Failed to drop partition"},
		{"audit_log", "Dropping old audit partition", "Failed to drop audit partition"},
	}

	for _, target := range targets {
		// Only names of the exact form <parent>_YYYY_MM are candidates, so
		// unrelated tables that merely share the prefix (for example
		// download_events_old) can never be dropped. Because the suffix has a
		// fixed width, comparing names compares months.
		var partitions []string
		err := pm.db.Raw(`
SELECT tablename FROM pg_tables
WHERE tablename ~ ?
AND tablename < ?
`, "^"+target.parent+"_[0-9]{4}_[0-9]{2}$", target.parent+"_"+cutoffPartition).Scan(&partitions).Error
		if err != nil {
			return fmt.Errorf("listing %s partitions: %w", target.parent, err)
		}

		for _, partition := range partitions {
			log.Info().Str("partition", partition).Msg(target.dropMsg)
			// The name was validated by the pattern above, so it contains
			// only lowercase letters, digits and underscores.
			if err := pm.db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", partition)).Error; err != nil {
				log.Error().Err(err).Str("partition", partition).Msg(target.failMsg)
			}
		}
	}
	return nil
}
// GetPartitionInfo returns information about current partitions.
//
// On a non-PostgreSQL database the result is {"status": "not_applicable"}.
// Otherwise the map contains:
//   - "download_events_partitions": number of tables named download_events_*
//   - "audit_log_partitions": number of tables named audit_log_*
//   - "largest_partitions": up to ten of those tables with their total size
//     in MiB, largest first
//
// A failing catalog query is returned as an error rather than being reported
// as a zero count or an empty list.
func (pm *PartitionManager) GetPartitionInfo() (map[string]interface{}, error) {
	if pm.db.Dialector.Name() != "postgres" {
		return map[string]interface{}{"status": "not_applicable"}, nil
	}

	var downloadCount int64
	if err := pm.db.Raw("SELECT COUNT(*) FROM pg_tables WHERE tablename LIKE 'download_events_%'").Scan(&downloadCount).Error; err != nil {
		return nil, fmt.Errorf("counting download_events partitions: %w", err)
	}

	var auditCount int64
	if err := pm.db.Raw("SELECT COUNT(*) FROM pg_tables WHERE tablename LIKE 'audit_log_%'").Scan(&auditCount).Error; err != nil {
		return nil, fmt.Errorf("counting audit_log partitions: %w", err)
	}

	// PartitionSize is one row of the size report.
	type PartitionSize struct {
		TableName string
		SizeMB    float64
	}
	// The relation is addressed by its quoted, schema-qualified name so the
	// lookup also works for tables outside the search path.
	var partitionSizes []PartitionSize
	err := pm.db.Raw(`
SELECT
tablename AS table_name,
pg_total_relation_size(format('%I.%I', schemaname, tablename)::regclass) / 1024.0 / 1024.0 AS size_mb
FROM pg_tables
WHERE tablename LIKE 'download_events_%' OR tablename LIKE 'audit_log_%'
ORDER BY size_mb DESC
LIMIT 10
`).Scan(&partitionSizes).Error
	if err != nil {
		return nil, fmt.Errorf("reading partition sizes: %w", err)
	}

	return map[string]interface{}{
		"download_events_partitions": downloadCount,
		"audit_log_partitions":       auditCount,
		"largest_partitions":         partitionSizes,
	}, nil
}
+12 -7
View File
@@ -143,13 +143,18 @@ const (
// Stats represents metadata statistics.
//
// Every field is exported and serialized to JSON under its snake_case tag.
// The package counters are followed by one vulnerability counter per severity
// level (critical, high, moderate, low).
type Stats struct {
	LastUpdated        time.Time `json:"last_updated"`
	Registry           string    `json:"registry"`
	TotalPackages      int64     `json:"total_packages"`
	TotalSize          int64     `json:"total_size"`
	TotalDownloads     int64     `json:"total_downloads"`
	ScannedPackages    int64     `json:"scanned_packages"`
	VulnerablePackages int64     `json:"vulnerable_packages"`
	// NOTE(review): presumably the number of packages blocked by the
	// vulnerability thresholds - confirm against the code that fills it.
	BlockedPackages         int64 `json:"blocked_packages"`
	CriticalVulnerabilities int64 `json:"critical_vulnerabilities"`
	HighVulnerabilities     int64 `json:"high_vulnerabilities"`
	ModerateVulnerabilities int64 `json:"moderate_vulnerabilities"`
	LowVulnerabilities      int64 `json:"low_vulnerabilities"`
}
// TimeSeriesDataPoint represents a single data point in time-series
File diff suppressed because it is too large Load Diff
+16 -3
View File
@@ -148,6 +148,17 @@ func (h *Handler) handlePackagePage(ctx context.Context, w http.ResponseWriter,
func (h *Handler) handlePackageFile(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
packageName, version := extractPackageFileInfo(path)
// Make version unique by appending file type to avoid cache collisions
// between .whl and .metadata files with same version
cacheVersion := version
if strings.HasSuffix(path, ".metadata") {
cacheVersion = version + ".metadata"
} else if strings.HasSuffix(path, ".whl") {
cacheVersion = version + ".whl"
} else if strings.HasSuffix(path, ".tar.gz") {
cacheVersion = version + ".tar.gz"
}
// Extract credentials from request
credentials := h.credExtractor.Extract(r)
credHash := h.credHasher.Hash(credentials)
@@ -170,12 +181,13 @@ func (h *Handler) handlePackageFile(ctx context.Context, w http.ResponseWriter,
Str("path", path).
Str("package", packageName).
Str("version", version).
Str("cache_version", cacheVersion).
Str("url", originalURL).
Str("cred_hash", credHash).
Bool("has_credentials", credentials != "").
Msg("Handling PyPI package file request")
entry, err := h.cache.Get(ctx, "pypi", packageName, version, func(ctx context.Context) (io.ReadCloser, string, error) {
entry, err := h.cache.Get(ctx, "pypi", packageName, cacheVersion, func(ctx context.Context) (io.ReadCloser, string, error) {
// Prepare headers for upstream request
headers := make(map[string]string)
if credentials != "" {
@@ -281,11 +293,12 @@ func isPackagePage(path string) bool {
// isPackageFile checks if the request is for a package file.
//
// It reports whether path ends in one of the recognised distribution
// suffixes: .whl, .tar.gz, .zip, .egg, or .metadata (the PEP 658 metadata
// file served alongside a distribution).
func isPackageFile(path string) bool {
	// Package files including .metadata files for PEP 658 support
	return strings.HasSuffix(path, ".whl") ||
		strings.HasSuffix(path, ".tar.gz") ||
		strings.HasSuffix(path, ".zip") ||
		strings.HasSuffix(path, ".egg") ||
		strings.HasSuffix(path, ".metadata")
}
// extractPackageName extracts package name from path