diff --git a/Makefile b/Makefile index 3fb2390..bd23ec4 100644 --- a/Makefile +++ b/Makefile @@ -47,7 +47,22 @@ bench: ## Run benchmarks @echo "Running benchmarks..." @go test -bench=. -benchmem ./... -run: build ## Build and run the server +run: build ## Build and run both backend and frontend for development + @echo "Starting $(BINARY_NAME) and frontend in development mode..." + @echo "" + @echo "Backend will run on: http://localhost:8080 (configured in config.yaml)" + @echo "Frontend will run on: http://localhost:5173 (configured in frontend/.env)" + @echo "" + @echo "To change ports:" + @echo " - Backend: Edit 'server.port' in config.yaml" + @echo " - Frontend: Edit 'VITE_PORT' and 'VITE_BACKEND_URL' in frontend/.env" + @echo "" + @trap 'kill 0' SIGINT; \ + $(BINARY_PATH) serve & \ + cd frontend && pnpm dev & \ + wait + +run-backend: build ## Build and run only the backend server @echo "Starting $(BINARY_NAME)..." @$(BINARY_PATH) serve @@ -62,6 +77,20 @@ clean: ## Clean build artifacts @rm -f *.db *.db-shm *.db-wal @echo "Clean complete" +clean-db: ## Clean all local cache and database files (from config.yaml paths) + @echo "WARNING: This will delete all cached packages and scan results!" + @echo "Paths from config.yaml:" + @echo " - ./data/storage (package cache)" + @echo " - ./data/gohoarder.db (metadata database)" + @echo " - /tmp/trivy (Trivy cache)" + @echo "" + @read -p "Are you sure you want to continue? [y/N] " confirm && [ "$$confirm" = "y" ] || exit 1 + @echo "Cleaning database and cache..." + @rm -rf ./data/storage + @rm -f ./data/gohoarder.db ./data/gohoarder.db-shm ./data/gohoarder.db-wal + @rm -rf /tmp/trivy + @echo "Database and cache cleaned successfully" + install: build ## Install the binary @echo "Installing $(BINARY_NAME)..." @cp $(BINARY_PATH) $(GOPATH)/bin/ @@ -92,4 +121,11 @@ docker-run: docker-build ## Run Docker container @echo "Running Docker container..." @docker run -p 8080:8080 $(BINARY_NAME):$(VERSION) +test-packages: ## Download test packages through gohoarder proxy (clean + vulnerable packages) + @echo "Reading backend port from config.yaml..." + @PORT=$$(grep "^ port:" config.yaml | awk '{print $$2}'); \ + if [ -z "$$PORT" ]; then PORT=8080; fi; \ + export GOHOARDER_URL="http://localhost:$$PORT"; \ + ./script/test-packages.sh + .DEFAULT_GOAL := help diff --git a/config.yaml.example b/config.yaml.example index dad30b7..49a5f75 100644 --- a/config.yaml.example +++ b/config.yaml.example @@ -1,8 +1,14 @@ # GoHoarder Configuration Example +# +# Port Configuration: +# - Backend server port is configured below (server.port) +# - Frontend dev server uses frontend/.env (VITE_PORT and VITE_BACKEND_URL) +# - When running `make run`, both will start with their configured ports +# - The frontend automatically proxies /api and /ws requests to the backend server: host: "0.0.0.0" - port: 8080 + port: 8080 # Backend API server port read_timeout: "5m" write_timeout: "5m" idle_timeout: "2m" @@ -62,18 +68,66 @@ cache: security: enabled: false block_on_severity: "high" # none, low, medium, high, critical + scan_on_download: true # Scan packages on first download + rescan_interval: "24h" # How often to re-scan packages (e.g., 24h, 168h for weekly) + update_db_on_startup: false # Update vulnerability databases on startup + allowed_packages: [] # Packages that bypass security checks (format: "registry/name@version") + ignored_cves: [] # CVE IDs to ignore globally (e.g., "CVE-2021-23337") + + block_thresholds: + critical: 0 # Max critical vulns (0 = block any) + high: -1 # Max high vulns (-1 = unlimited) + medium: -1 # Max medium vulns + low: -1 # Max low vulns scanners: + # Trivy - Comprehensive vulnerability scanner from Aqua Security + # Supports: containers, OS packages, language packages trivy: enabled: false timeout: "5m" cache_db: "/var/lib/trivy" + # OSV - Google's Open Source Vulnerabilities database + # Supports: npm, PyPI, Go, Maven, NuGet, etc. osv: enabled: false api_url: "https://api.osv.dev" timeout: "30s" + # Grype - Multi-ecosystem vulnerability scanner from Anchore + # Supports: all package types, containers, SBOMs + grype: + enabled: false + timeout: "5m" + + # govulncheck - Official Go vulnerability scanner from the Go team + # Supports: Go modules only + govulncheck: + enabled: false + timeout: "5m" + + # npm-audit - npm's built-in vulnerability scanner + # Supports: npm packages only + npm_audit: + enabled: false + timeout: "2m" + + # pip-audit - Python package vulnerability scanner + # Supports: PyPI packages only + pip_audit: + enabled: false + timeout: "2m" + + # GitHub Advisory Database - GitHub's security advisory database + # Supports: npm, pip, go, maven, nuget, cargo, pub + # Optional: Set token for higher API rate limits (60 req/hour unauthenticated, 5000 req/hour authenticated) + ghsa: + enabled: false + timeout: "30s" + token: "" # Optional: GitHub personal access token (ghp_...) + + # Static Analysis - Basic static analysis and package validation static: enabled: true max_package_size: 2147483648 # 2GB diff --git a/frontend/.env.example b/frontend/.env.example new file mode 100644 index 0000000..22001c8 --- /dev/null +++ b/frontend/.env.example @@ -0,0 +1,7 @@ +# Backend API URL (used by Vite dev server proxy) +# Change this if your gohoarder backend is running on a different port +VITE_BACKEND_URL=http://localhost:8080 + +# Frontend dev server port +# The Vite development server will run on this port +VITE_PORT=5173 diff --git a/frontend/src/components/Dashboard.vue b/frontend/src/components/Dashboard.vue index 4db944a..7281ecc 100644 --- a/frontend/src/components/Dashboard.vue +++ b/frontend/src/components/Dashboard.vue @@ -110,10 +110,10 @@ {{ getChartLabel(index) }} -
+

- Chart data will be available once backend API exposes time-series statistics + {{ chartLoading ? 'Loading chart data...' : 'No download activity in this period' }}

@@ -121,7 +121,7 @@ -

+

Recent Packages

@@ -178,13 +178,15 @@ diff --git a/frontend/src/components/PackageList.vue b/frontend/src/components/PackageList.vue index 45c17bc..6c49a9a 100644 --- a/frontend/src/components/PackageList.vue +++ b/frontend/src/components/PackageList.vue @@ -258,6 +258,7 @@ import { Card, CardContent } from '@/components/ui/card' import { Badge } from '@/components/ui/badge' import { Input } from '@/components/ui/input' import VulnerabilityBadge from './VulnerabilityBadge.vue' +import { getRegistryBadgeClass } from '@/composables/useBadgeStyles' // Props from router const props = defineProps<{ @@ -382,15 +383,6 @@ async function deletePackage() { } } -function getRegistryBadgeClass(registry: string): string { - const classes: Record = { - npm: 'bg-blue-100 text-blue-800 border-blue-200', - pypi: 'bg-green-100 text-green-800 border-green-200', - go: 'bg-yellow-100 text-yellow-800 border-yellow-200', - } - return classes[registry] || 'bg-gray-100 text-gray-800 border-gray-200' -} - function formatNumber(num: number): string { return new Intl.NumberFormat().format(num) } diff --git a/frontend/src/components/Stats.vue b/frontend/src/components/Stats.vue index ad6fd7b..5dc5723 100644 --- a/frontend/src/components/Stats.vue +++ b/frontend/src/components/Stats.vue @@ -19,7 +19,7 @@ -

+

Overall Statistics

@@ -48,7 +48,7 @@ -

+

Security Scanning

@@ -61,12 +61,19 @@
-
+

{{ formatNumber(stats?.vulnerable_packages || 0) }}

-

Vulnerable Packages

+

+ Vulnerable Packages + (click to view) +

@@ -77,7 +84,7 @@ -

+

Registry Breakdown

@@ -113,17 +120,27 @@ diff --git a/frontend/src/composables/useBadgeStyles.ts b/frontend/src/composables/useBadgeStyles.ts new file mode 100644 index 0000000..c62b7fe --- /dev/null +++ b/frontend/src/composables/useBadgeStyles.ts @@ -0,0 +1,59 @@ +/** + * Shared badge styling utilities for consistent UI across the application + */ + +/** + * Get Tailwind CSS classes for severity badges (light theme) + * @param severity - Severity level (CRITICAL, HIGH, MODERATE/MEDIUM, LOW) + * @returns Tailwind CSS class string + */ +export function getSeverityBadgeClass(severity: string): string { + const classes: Record = { + CRITICAL: 'bg-red-100 text-red-800 border-red-300', + HIGH: 'bg-orange-100 text-orange-800 border-orange-300', + MEDIUM: 'bg-yellow-100 text-yellow-800 border-yellow-300', + MODERATE: 'bg-yellow-100 text-yellow-800 border-yellow-300', + LOW: 'bg-blue-100 text-blue-800 border-blue-300', + } + return classes[severity.toUpperCase()] || 'bg-gray-100 text-gray-800 border-gray-300' +} + +/** + * Get Tailwind CSS classes for registry badges (light theme) + * @param registry - Registry name (npm, pypi, go) + * @returns Tailwind CSS class string + */ +export function getRegistryBadgeClass(registry: string): string { + const classes: Record = { + npm: 'bg-red-100 text-red-800 border-red-300', + pypi: 'bg-blue-100 text-blue-800 border-blue-300', + go: 'bg-cyan-100 text-cyan-800 border-cyan-300', + } + return classes[registry.toLowerCase()] || 'bg-gray-100 text-gray-800 border-gray-300' +} + +/** + * Get Tailwind CSS classes for vulnerability border indicators + * @param severity - Severity level (CRITICAL, HIGH, MODERATE/MEDIUM, LOW) + * @returns Tailwind CSS class string for left border + */ +export function getVulnerabilityBorderClass(severity: string): string { + const classes: Record = { + CRITICAL: 'border-l-4 border-l-red-600', + HIGH: 'border-l-4 border-l-orange-500', + MEDIUM: 'border-l-4 border-l-yellow-500', + MODERATE: 'border-l-4 border-l-yellow-500', + LOW: 'border-l-4 border-l-blue-500', + } + return classes[severity.toUpperCase()] || 'border-l-4 border-l-gray-500' +} + +/** + * Format severity name for display (title case) + * @param severity - Severity level (e.g., "CRITICAL", "HIGH") + * @returns Formatted severity name (e.g., "Critical", "High") + */ +export function formatSeverityName(severity: string): string { + const normalized = severity.toUpperCase() + return normalized.charAt(0) + normalized.slice(1).toLowerCase() +} diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index 9d2cd4a..206dfe3 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -3,6 +3,7 @@ import Dashboard from '../components/Dashboard.vue' import PackageList from '../components/PackageList.vue' import PackageDetails from '../components/PackageDetails.vue' import Stats from '../components/Stats.vue' +import VulnerablePackages from '../components/VulnerablePackages.vue' import BypassManagementPanel from '../components/BypassManagementPanel.vue' const router = createRouter({ @@ -31,6 +32,11 @@ const router = createRouter({ name: 'stats', component: Stats, }, + { + path: '/vulnerable-packages', + name: 'vulnerable-packages', + component: VulnerablePackages, + }, { path: '/admin/bypasses', name: 'bypasses', diff --git a/frontend/src/stores/packages.ts b/frontend/src/stores/packages.ts index 1020e2a..6380be2 100644 --- a/frontend/src/stores/packages.ts +++ b/frontend/src/stores/packages.ts @@ -51,8 +51,8 @@ export const usePackageStore = defineStore('packages', () => { try { const response = await axios.get('/api/packages') // Only update packages if we got valid data - if (response.data && response.data.data && Array.isArray(response.data.data.packages)) { - packages.value = response.data.data.packages + if (response.data && Array.isArray(response.data.packages)) { + packages.value = response.data.packages } else { console.warn('Unexpected API response format:', response.data) error.value = 'Unexpected response format from server' @@ -73,9 +73,9 @@ export const usePackageStore = defineStore('packages', () => { const url = registry ? `/api/stats?registry=${registry}` : '/api/stats' const response = await axios.get(url) // Only update stats if we got valid data - if (response.data && response.data.data && response.data.data.stats) { - stats.value = response.data.data.stats - registries.value = response.data.data.registries || {} + if (response.data && response.data.stats) { + stats.value = response.data.stats + registries.value = response.data.registries || {} } else { console.warn('Unexpected stats response format:', response.data) error.value = 'Unexpected stats response format from server' diff --git a/frontend/vite.config.ts b/frontend/vite.config.ts index 39193fd..f73c940 100644 --- a/frontend/vite.config.ts +++ b/frontend/vite.config.ts @@ -2,6 +2,10 @@ import { defineConfig } from 'vite' import vue from '@vitejs/plugin-vue' import path from 'path' +// Get backend URL from environment or use default +const BACKEND_URL = process.env.VITE_BACKEND_URL || 'http://localhost:8080' +const FRONTEND_PORT = parseInt(process.env.VITE_PORT || '5173') + export default defineConfig({ plugins: [vue()], resolve: { @@ -10,10 +14,15 @@ export default defineConfig({ }, }, server: { - port: 3000, + port: FRONTEND_PORT, proxy: { '/api': { - target: 'http://localhost:8080', + target: BACKEND_URL, + changeOrigin: true, + }, + '/ws': { + target: BACKEND_URL.replace('http', 'ws'), + ws: true, changeOrigin: true, }, }, diff --git a/go.mod b/go.mod index cfbedbd..56918b0 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/aws/aws-sdk-go-v2/credentials v1.19.6 github.com/aws/aws-sdk-go-v2/service/s3 v1.95.0 github.com/goccy/go-json v0.10.5 + github.com/gofiber/fiber/v2 v2.52.10 github.com/gorilla/websocket v1.5.3 github.com/hirochachacha/go-smb2 v1.1.0 github.com/prometheus/client_golang v1.23.2 @@ -23,6 +24,7 @@ require ( ) require ( + github.com/andybalholm/brotli v1.1.0 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.4 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 // indirect @@ -48,8 +50,10 @@ require ( github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/klauspost/compress v1.18.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-runewidth v0.0.16 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect @@ -58,12 +62,16 @@ require ( github.com/prometheus/common v0.67.4 // indirect github.com/prometheus/procfs v0.19.2 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/rivo/uniseg v0.2.0 // indirect github.com/sagikazarmark/locafero v0.12.0 // indirect github.com/spf13/afero v1.15.0 // indirect github.com/spf13/cast v1.10.0 // indirect github.com/spf13/pflag v1.0.10 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.51.0 // indirect + github.com/valyala/tcplisten v1.0.0 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 // indirect diff --git a/go.sum b/go.sum index 2f40943..e32f1cd 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= +github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= github.com/aws/aws-sdk-go-v2 v1.41.0/go.mod h1:MayyLB8y+buD9hZqkCW3kX1AKq07Y5pXxtgB+rRFhz0= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.4 h1:489krEF9xIGkOaaX3CE/Be2uWjiXrkCH6gUX+bZA/BU= @@ -64,6 +66,8 @@ github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlnd github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/gofiber/fiber/v2 v2.52.10 h1:jRHROi2BuNti6NYXmZ6gbNSfT3zj/8c0xy94GOU5elY= +github.com/gofiber/fiber/v2 v2.52.10/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= @@ -93,6 +97,8 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= +github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= @@ -114,6 +120,8 @@ github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4Vi github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= @@ -139,6 +147,12 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA= +github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdIA3Xl7cH8g= +github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= +github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= diff --git a/gohoarder.pid b/gohoarder.pid index b2fcc59..73a0648 100644 --- a/gohoarder.pid +++ b/gohoarder.pid @@ -1 +1 @@ -20805 +bd25b1e diff --git a/pkg/app/app.go b/pkg/app/app.go index 526dc7b..8698ece 100644 --- a/pkg/app/app.go +++ b/pkg/app/app.go @@ -9,6 +9,8 @@ import ( "syscall" "time" + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/adaptor" "github.com/lukaszraczylo/gohoarder/pkg/analytics" "github.com/lukaszraczylo/gohoarder/pkg/auth" "github.com/lukaszraczylo/gohoarder/pkg/cache" @@ -16,7 +18,6 @@ import ( "github.com/lukaszraczylo/gohoarder/pkg/config" "github.com/lukaszraczylo/gohoarder/pkg/health" "github.com/lukaszraczylo/gohoarder/pkg/lock" - "github.com/lukaszraczylo/gohoarder/pkg/logger" "github.com/lukaszraczylo/gohoarder/pkg/metadata" metafile "github.com/lukaszraczylo/gohoarder/pkg/metadata/file" metasqlite "github.com/lukaszraczylo/gohoarder/pkg/metadata/sqlite" @@ -36,7 +37,7 @@ import ( // App represents the main application type App struct { config *config.Config - server *http.Server + app *fiber.App healthChecker *health.Checker cache *cache.Manager storage storage.StorageBackend @@ -163,7 +164,7 @@ func (a *App) initializeComponents() error { a.wsServer = websocket.NewServer(websocket.Config{ ReadBufferSize: 1024, WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { + CheckOrigin: func(_ *http.Request) bool { return true // Allow all origins in development }, }) @@ -221,55 +222,60 @@ func (a *App) initializeComponents() error { return nil } -// setupServer sets up the HTTP server and routes +// setupServer sets up the Fiber server and routes func (a *App) setupServer() error { - mux := http.NewServeMux() + // Create Fiber app + a.app = fiber.New(fiber.Config{ + ReadTimeout: a.config.Server.ReadTimeout, + WriteTimeout: a.config.Server.WriteTimeout, + ServerHeader: "GoHoarder", + AppName: "GoHoarder v1.0", + }) - // Health and metrics endpoints - mux.HandleFunc("/health", a.healthChecker.HealthHandler()) - mux.HandleFunc("/health/ready", a.healthChecker.ReadyHandler()) - mux.Handle("/metrics", metrics.Handler()) + // Health and metrics endpoints (adapted from net/http) + a.app.Get("/health", adaptor.HTTPHandlerFunc(a.healthChecker.HealthHandler())) + a.app.Get("/health/ready", adaptor.HTTPHandlerFunc(a.healthChecker.ReadyHandler())) + a.app.Get("/metrics", adaptor.HTTPHandler(metrics.Handler())) - // WebSocket endpoint - mux.HandleFunc("/ws", a.wsServer.HandleWebSocket) + // WebSocket endpoint (adapted from net/http) + a.app.Get("/ws", adaptor.HTTPHandlerFunc(a.wsServer.HandleWebSocket)) // API endpoints - mux.HandleFunc("/api/packages/", a.handlePackages) // Handles packages and vulnerabilities - mux.HandleFunc("/api/stats", a.handleStats) - mux.HandleFunc("/api/info", a.handleInfo) + a.app.Get("/api/config", a.handleConfig) + a.app.All("/api/packages/*", a.handlePackages) // Handles packages and vulnerabilities + a.app.Get("/api/stats", a.handleStats) + a.app.Get("/api/stats/timeseries", a.handleTimeSeriesStats) + a.app.Get("/api/info", a.handleInfo) // Admin endpoints (bypass management) - mux.HandleFunc("/api/admin/bypasses/", a.handleBypassByID) // Must come before /api/admin/bypasses - mux.HandleFunc("/api/admin/bypasses", a.handleAdminBypasses) + a.app.All("/api/admin/bypasses/:id?", a.handleAdminBypasses) - // Proxy handlers + // Proxy handlers (adapted from net/http) goProxyHandler := goproxy.New(a.cache, a.networkClient, goproxy.Config{ Upstream: "https://proxy.golang.org", SumDBURL: "https://sum.golang.org", }) - mux.Handle("/go/", http.StripPrefix("/go", goProxyHandler)) + a.app.All("/go/*", adaptor.HTTPHandler(http.StripPrefix("/go", goProxyHandler))) npmProxyHandler := npm.New(a.cache, a.networkClient, npm.Config{ Upstream: "https://registry.npmjs.org", }) - mux.Handle("/npm/", http.StripPrefix("/npm", npmProxyHandler)) + a.app.All("/npm/*", adaptor.HTTPHandler(http.StripPrefix("/npm", npmProxyHandler))) pypiProxyHandler := pypi.New(a.cache, a.networkClient, pypi.Config{ Upstream: "https://pypi.org/simple", }) - mux.Handle("/pypi/", http.StripPrefix("/pypi", pypiProxyHandler)) + a.app.All("/pypi/*", adaptor.HTTPHandler(http.StripPrefix("/pypi", pypiProxyHandler))) // Serve frontend static files frontendDir := "frontend/dist" if _, err := os.Stat(frontendDir); err == nil { log.Info().Str("dir", frontendDir).Msg("Serving frontend static files") - fs := http.FileServer(http.Dir(frontendDir)) - mux.Handle("/", fs) + a.app.Static("/", frontendDir) } else { log.Warn().Msg("Frontend dist directory not found, frontend won't be served") - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") - fmt.Fprintf(w, ` + a.app.Get("/", func(c *fiber.Ctx) error { + return c.Type("html").SendString(` GoHoarder @@ -287,20 +293,9 @@ func (a *App) setupServer() error { }) } - // Wrap with logging middleware - handler := logger.Middleware(mux) - - // Create HTTP server - a.server = &http.Server{ - Addr: fmt.Sprintf("%s:%d", a.config.Server.Host, a.config.Server.Port), - Handler: handler, - ReadTimeout: a.config.Server.ReadTimeout, - WriteTimeout: a.config.Server.WriteTimeout, - } - log.Info(). - Str("addr", a.server.Addr). - Msg("HTTP server configured") + Str("addr", fmt.Sprintf("%s:%d", a.config.Server.Host, a.config.Server.Port)). + Msg("Fiber server configured") return nil } @@ -320,13 +315,17 @@ func (a *App) Run() error { go a.rescanWorker.Start(ctx) } - // Start HTTP server in goroutine + // Start download data aggregation worker (runs every hour) + go a.startAggregationWorker(ctx) + + // Start Fiber server in goroutine errChan := make(chan error, 1) go func() { + addr := fmt.Sprintf("%s:%d", a.config.Server.Host, a.config.Server.Port) log.Info(). - Str("addr", a.server.Addr). - Msg("Starting HTTP server") - if err := a.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + Str("addr", addr). + Msg("Starting Fiber server") + if err := a.app.Listen(addr); err != nil { errChan <- err } }() @@ -352,12 +351,9 @@ func (a *App) Run() error { func (a *App) Shutdown() error { log.Info().Msg("Starting graceful shutdown") - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - // Stop HTTP server - if err := a.server.Shutdown(ctx); err != nil { - log.Error().Err(err).Msg("Error shutting down HTTP server") + // Stop Fiber server + if err := a.app.Shutdown(); err != nil { + log.Error().Err(err).Msg("Error shutting down Fiber server") } // Stop pre-warming worker @@ -391,3 +387,29 @@ func (a *App) Shutdown() error { log.Info().Msg("Shutdown complete") return nil } + +// startAggregationWorker runs download data aggregation periodically +func (a *App) startAggregationWorker(ctx context.Context) { + log.Info().Msg("Starting download data aggregation worker (runs every hour)") + + // Run immediately on startup + if err := a.metadata.AggregateDownloadData(ctx); err != nil { + log.Error().Err(err).Msg("Failed to run initial download data aggregation") + } + + // Then run every hour + ticker := time.NewTicker(1 * time.Hour) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + log.Info().Msg("Aggregation worker stopped") + return + case <-ticker.C: + if err := a.metadata.AggregateDownloadData(ctx); err != nil { + log.Error().Err(err).Msg("Failed to aggregate download data") + } + } + } +} diff --git a/pkg/app/handlers.go b/pkg/app/handlers.go index 14136b5..ac9d55c 100644 --- a/pkg/app/handlers.go +++ b/pkg/app/handlers.go @@ -1,48 +1,45 @@ package app import ( - "net/http" "strings" "time" + "github.com/gofiber/fiber/v2" "github.com/lukaszraczylo/gohoarder/internal/version" - "github.com/lukaszraczylo/gohoarder/pkg/errors" "github.com/lukaszraczylo/gohoarder/pkg/metadata" "github.com/lukaszraczylo/gohoarder/pkg/websocket" "github.com/rs/zerolog/log" ) // handlePackages handles /api/packages endpoint -func (a *App) handlePackages(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, DELETE, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type") +func (a *App) handlePackages(c *fiber.Ctx) error { + c.Set("Content-Type", "application/json") + c.Set("Access-Control-Allow-Origin", "*") + c.Set("Access-Control-Allow-Methods", "GET, DELETE, OPTIONS") + c.Set("Access-Control-Allow-Headers", "Content-Type") - if r.Method == "OPTIONS" { - w.WriteHeader(http.StatusOK) - return + if c.Method() == "OPTIONS" { + return c.SendStatus(fiber.StatusOK) } // Check if this is a vulnerability endpoint request - if strings.HasSuffix(r.URL.Path, "/vulnerabilities") { - a.handleVulnerabilities(w, r) - return + if strings.HasSuffix(c.Path(), "/vulnerabilities") { + return a.handleVulnerabilities(c) } - switch r.Method { + switch c.Method() { case "GET": - a.handleListPackages(w, r) + return a.handleListPackages(c) case "DELETE": - a.handleDeletePackage(w, r) + return a.handleDeletePackage(c) default: - errors.WriteErrorSimple(w, errors.BadRequest("method not allowed")) + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "method not allowed"}) } } // handleListPackages returns list of cached packages -func (a *App) handleListPackages(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() +func (a *App) handleListPackages(c *fiber.Ctx) error { + ctx := c.Context() // Get packages from metadata store allPackages, err := a.metadata.ListPackages(ctx, &metadata.ListOptions{ @@ -51,19 +48,33 @@ func (a *App) handleListPackages(w http.ResponseWriter, r *http.Request) { }) if err != nil { log.Error().Err(err).Msg("Failed to list packages") - errors.WriteErrorSimple(w, errors.InternalServer("failed to list packages")) - return + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "failed to list packages"}) } + log.Debug().Int("total_packages_from_db", len(allPackages)).Msg("Retrieved packages from database") + // Filter, clean, and deduplicate packages - seen := make(map[string]*metadata.Package) + // Map stores both cleaned package and original name for scan lookups + type packageEntry struct { + pkg *metadata.Package + originalName string + } + seen := make(map[string]*packageEntry) + skippedCount := 0 for _, pkg := range allPackages { // Skip metadata entries (npm metadata pages, pypi pages, etc.) if pkg.Version == "list" || pkg.Version == "latest" || pkg.Version == "metadata" || pkg.Version == "page" { + skippedCount++ + log.Debug(). + Str("name", pkg.Name). + Str("version", pkg.Version). + Str("registry", pkg.Registry). + Msg("Skipping metadata entry") continue } // Clean the package name (remove /@v/version.ext suffix) + originalName := pkg.Name cleanName := pkg.Name if idx := strings.Index(cleanName, "/@v/"); idx != -1 { cleanName = cleanName[:idx] @@ -73,25 +84,41 @@ func (a *App) handleListPackages(w http.ResponseWriter, r *http.Request) { key := cleanName + "@" + pkg.Version // Keep the entry with the largest size (typically .zip files) - if existing, ok := seen[key]; !ok || pkg.Size > existing.Size { + if existing, ok := seen[key]; !ok || pkg.Size > existing.pkg.Size { // Create a copy with cleaned name cleanPkg := *pkg cleanPkg.Name = cleanName - seen[key] = &cleanPkg + seen[key] = &packageEntry{ + pkg: &cleanPkg, + originalName: originalName, + } } } - // Convert map to slice - packages := make([]*metadata.Package, 0, len(seen)) - for _, pkg := range seen { - packages = append(packages, pkg) + log.Debug(). + Int("skipped_metadata", skippedCount). + Int("unique_packages", len(seen)). + Msg("Filtered and deduplicated packages") + + // Convert map to slice, keeping track of original names + type packageWithOriginalName struct { + pkg *metadata.Package + originalName string + } + packagesWithNames := make([]packageWithOriginalName, 0, len(seen)) + for _, entry := range seen { + packagesWithNames = append(packagesWithNames, packageWithOriginalName{ + pkg: entry.pkg, + originalName: entry.originalName, + }) } // Enhance packages with vulnerability information if security scanning is enabled var response map[string]interface{} if a.config.Security.Enabled { - enhancedPackages := make([]map[string]interface{}, 0, len(packages)) - for _, pkg := range packages { + enhancedPackages := make([]map[string]interface{}, 0, len(packagesWithNames)) + for _, entry := range packagesWithNames { + pkg := entry.pkg pkgMap := map[string]interface{}{ "id": pkg.ID, "registry": pkg.Registry, @@ -106,7 +133,8 @@ func (a *App) handleListPackages(w http.ResponseWriter, r *http.Request) { // Add vulnerability info if scanned if pkg.SecurityScanned { - scanResult, err := a.metadata.GetScanResult(ctx, pkg.Registry, pkg.Name, pkg.Version) + // Use original name for scan result lookup (handles Go packages with /@v/ suffix) + scanResult, err := a.metadata.GetScanResult(ctx, pkg.Registry, entry.originalName, pkg.Version) if err == nil && scanResult != nil { // Count vulnerabilities by severity severityCounts := make(map[string]int) @@ -115,8 +143,8 @@ func (a *App) handleListPackages(w http.ResponseWriter, r *http.Request) { } pkgMap["vulnerabilities"] = map[string]interface{}{ - "scanned": true, - "status": scanResult.Status, + "scanned": true, + "status": scanResult.Status, "scannedAt": scanResult.ScannedAt.Format(time.RFC3339), "counts": map[string]int{ "critical": severityCounts["CRITICAL"], @@ -147,6 +175,11 @@ func (a *App) handleListPackages(w http.ResponseWriter, r *http.Request) { "total": len(enhancedPackages), } } else { + // Non-enhanced mode - just return the packages + packages := make([]*metadata.Package, 0, len(packagesWithNames)) + for _, entry := range packagesWithNames { + packages = append(packages, entry.pkg) + } response = map[string]interface{}{ "packages": packages, "total": len(packages), @@ -154,21 +187,22 @@ func (a *App) handleListPackages(w http.ResponseWriter, r *http.Request) { } // Success response - errors.WriteJSONSimple(w, http.StatusOK, response) + return c.Status(fiber.StatusOK).JSON(response) } // handleDeletePackage deletes a cached package -func (a *App) handleDeletePackage(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() +func (a *App) handleDeletePackage(c *fiber.Ctx) error { + ctx := c.Context() // Parse path: /api/packages/{registry}/{name}/{version} // For Go packages, name can contain slashes (e.g., github.com/user/repo) // Version is always the last segment - path := strings.TrimPrefix(r.URL.Path, "/api/packages/") + path := strings.TrimPrefix(c.Path(), "/api/packages/") parts := strings.Split(path, "/") if len(parts) < 3 { - errors.WriteErrorSimple(w, errors.BadRequest("invalid path format, expected /api/packages/{registry}/{name}/{version}")) - return + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + "error": "invalid path format, expected /api/packages/{registry}/{name}/{version}", + }) } registry := parts[0] @@ -187,8 +221,7 @@ func (a *App) handleDeletePackage(w http.ResponseWriter, r *http.Request) { }) if err != nil { log.Error().Err(err).Msg("Failed to list packages for deletion") - errors.WriteErrorSimple(w, errors.InternalServer("failed to list packages")) - return + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "failed to list packages"}) } log.Debug(). @@ -242,13 +275,11 @@ func (a *App) handleDeletePackage(w http.ResponseWriter, r *http.Request) { Msg("Delete operation completed") if deletedCount == 0 { - errors.WriteErrorSimple(w, errors.NotFound("package not found")) - return + return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "package not found"}) } if lastErr != nil && deletedCount == 0 { - errors.WriteErrorSimple(w, errors.InternalServer("failed to delete package")) - return + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "failed to delete package"}) } } else { // For NPM and PyPI, delete directly @@ -259,8 +290,7 @@ func (a *App) handleDeletePackage(w http.ResponseWriter, r *http.Request) { Str("name", name). Str("version", version). Msg("Failed to delete package") - errors.WriteErrorSimple(w, errors.InternalServer("failed to delete package")) - return + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "failed to delete package"}) } deletedCount = 1 } @@ -287,46 +317,41 @@ func (a *App) handleDeletePackage(w http.ResponseWriter, r *http.Request) { response["deleted_count"] = deletedCount } - errors.WriteJSONSimple(w, http.StatusOK, response) + return c.Status(fiber.StatusOK).JSON(response) } // handleStats handles /api/stats endpoint -func (a *App) handleStats(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type") +func (a *App) handleStats(c *fiber.Ctx) error { + c.Set("Content-Type", "application/json") + c.Set("Access-Control-Allow-Origin", "*") + c.Set("Access-Control-Allow-Methods", "GET, OPTIONS") + c.Set("Access-Control-Allow-Headers", "Content-Type") - if r.Method == "OPTIONS" { - w.WriteHeader(http.StatusOK) - return + if c.Method() == "OPTIONS" { + return c.SendStatus(fiber.StatusOK) } - if r.Method != "GET" { - errors.WriteErrorSimple(w, errors.BadRequest("method not allowed")) - return + if c.Method() != "GET" { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "method not allowed"}) } - ctx := r.Context() + ctx := c.Context() - // Get cache statistics for all registries + // Get cache statistics for all registries from database cacheStats, err := a.cache.GetStats(ctx, "") if err != nil { log.Error().Err(err).Msg("Failed to get cache stats") cacheStats = &metadata.Stats{} } - // Get all packages to calculate total size and downloads + // Get all packages to calculate per-registry breakdown packages, err := a.metadata.ListPackages(ctx, nil) if err != nil { log.Error().Err(err).Msg("Failed to list packages") packages = []*metadata.Package{} } - // Calculate totals and registry breakdown from actual packages (exclude metadata entries like "list", "latest") - var totalSize int64 - var totalDownloads int64 - var actualPackageCount int + // Calculate per-registry breakdown (exclude metadata entries like "list", "latest") registryStats := make(map[string]map[string]interface{}) for _, pkg := range packages { @@ -334,9 +359,6 @@ func (a *App) handleStats(w http.ResponseWriter, r *http.Request) { if pkg.Version == "list" || pkg.Version == "latest" || pkg.Version == "metadata" || pkg.Version == "page" { continue } - totalSize += pkg.Size - totalDownloads += int64(pkg.DownloadCount) - actualPackageCount++ // Track per-registry stats if _, ok := registryStats[pkg.Registry]; !ok { @@ -351,11 +373,11 @@ func (a *App) handleStats(w http.ResponseWriter, r *http.Request) { registryStats[pkg.Registry]["downloads"] = registryStats[pkg.Registry]["downloads"].(int64) + int64(pkg.DownloadCount) } - // Combine statistics + // Combine statistics using database stats for accuracy stats := map[string]interface{}{ - "total_packages": actualPackageCount, - "total_downloads": totalDownloads, - "total_size": totalSize, + "total_packages": cacheStats.TotalPackages, + "total_downloads": cacheStats.TotalDownloads, + "total_size": cacheStats.TotalSize, "cache_hits": cacheStats.TotalDownloads, "cache_misses": 0, // TODO: Track cache misses "cache_evictions": 0, // TODO: Track evictions @@ -370,27 +392,102 @@ func (a *App) handleStats(w http.ResponseWriter, r *http.Request) { registries[registry] = regStats } - errors.WriteJSONSimple(w, http.StatusOK, map[string]interface{}{ + return c.Status(fiber.StatusOK).JSON(fiber.Map{ "stats": stats, "registries": registries, }) } -// handleInfo handles /api/info endpoint -func (a *App) handleInfo(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type") +// handleTimeSeriesStats handles /api/stats/timeseries endpoint +// Returns time-series download statistics for charts +func (a *App) handleTimeSeriesStats(c *fiber.Ctx) error { + c.Set("Content-Type", "application/json") + c.Set("Access-Control-Allow-Origin", "*") + c.Set("Access-Control-Allow-Methods", "GET, OPTIONS") + c.Set("Access-Control-Allow-Headers", "Content-Type") - if r.Method == "OPTIONS" { - w.WriteHeader(http.StatusOK) - return + if c.Method() == "OPTIONS" { + return c.SendStatus(fiber.StatusOK) } - if r.Method != "GET" { - errors.WriteErrorSimple(w, errors.BadRequest("method not allowed")) - return + if c.Method() != "GET" { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "method not allowed"}) + } + + ctx := c.Context() + + // Get query parameters + period := c.Query("period", "1day") // Default to 1 day + registry := c.Query("registry") // Optional registry filter + + // Validate period + validPeriods := map[string]bool{"1h": true, "1day": true, "7day": true, "30day": true} + if !validPeriods[period] { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + "error": "invalid period, must be one of: 1h, 1day, 7day, 30day", + }) + } + + // Get time-series stats + stats, err := a.metadata.GetTimeSeriesStats(ctx, period, registry) + if err != nil { + log.Error().Err(err).Str("period", period).Str("registry", registry).Msg("Failed to get time-series stats") + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + "error": "failed to get time-series statistics", + }) + } + + return c.Status(fiber.StatusOK).JSON(stats) +} + +// handleConfig handles /api/config endpoint +// Returns runtime configuration for the frontend +func (a *App) handleConfig(c *fiber.Ctx) error { + c.Set("Content-Type", "application/json") + c.Set("Access-Control-Allow-Origin", "*") + c.Set("Access-Control-Allow-Methods", "GET, OPTIONS") + c.Set("Access-Control-Allow-Headers", "Content-Type") + + if c.Method() == "OPTIONS" { + return c.SendStatus(fiber.StatusOK) + } + + if c.Method() != "GET" { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "method not allowed"}) + } + + // Build server URL from request + scheme := "http" + if c.Protocol() == "https" { + scheme = "https" + } + serverURL := scheme + "://" + c.Hostname() + + config := map[string]interface{}{ + "server_url": serverURL, + "version": version.Version, + "features": map[string]bool{ + "security_scanning": a.config.Security.Enabled, + "websockets": true, + }, + } + + return c.Status(fiber.StatusOK).JSON(config) +} + +// handleInfo handles /api/info endpoint +func (a *App) handleInfo(c *fiber.Ctx) error { + c.Set("Content-Type", "application/json") + c.Set("Access-Control-Allow-Origin", "*") + c.Set("Access-Control-Allow-Methods", "GET, OPTIONS") + c.Set("Access-Control-Allow-Headers", "Content-Type") + + if c.Method() == "OPTIONS" { + return c.SendStatus(fiber.StatusOK) + } + + if c.Method() != "GET" { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "method not allowed"}) } info := map[string]interface{}{ @@ -411,5 +508,5 @@ func (a *App) handleInfo(w http.ResponseWriter, r *http.Request) { }, } - errors.WriteJSONSimple(w, http.StatusOK, info) + return c.Status(fiber.StatusOK).JSON(info) } diff --git a/pkg/app/handlers_admin.go b/pkg/app/handlers_admin.go index 9765934..14f5396 100644 --- a/pkg/app/handlers_admin.go +++ b/pkg/app/handlers_admin.go @@ -1,111 +1,94 @@ package app import ( - "encoding/json" - "io" - "net/http" "strings" "time" + "github.com/gofiber/fiber/v2" "github.com/lukaszraczylo/gohoarder/pkg/auth" - "github.com/lukaszraczylo/gohoarder/pkg/errors" "github.com/lukaszraczylo/gohoarder/pkg/metadata" "github.com/lukaszraczylo/gohoarder/pkg/uuid" "github.com/rs/zerolog/log" ) // requireAdmin middleware checks for admin authentication -func (a *App) requireAdmin(next http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - // Get API key from Authorization header - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - errors.WriteErrorSimple(w, errors.New(errors.ErrCodeUnauthorized, "missing authorization header")) - return - } - - // Extract bearer token - parts := strings.SplitN(authHeader, " ", 2) - if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" { - errors.WriteErrorSimple(w, errors.New(errors.ErrCodeUnauthorized, "invalid authorization header format, expected: Bearer ")) - return - } - - apiKey := parts[1] - - // Validate API key - key, err := a.authManager.ValidateAPIKey(r.Context(), apiKey) - if err != nil { - errors.WriteErrorSimple(w, errors.New(errors.ErrCodeUnauthorized, "invalid or expired API key")) - return - } - - // Check if user has admin role or bypass management permission - if key.Role != auth.RoleAdmin && !key.HasPermission(auth.PermissionManageBypasses) { - errors.WriteErrorSimple(w, errors.New(errors.ErrCodeForbidden, "insufficient permissions, admin role required")) - return - } - - // Store user info in request context for handlers to use - // For now, we'll just proceed - could enhance with context.WithValue - next(w, r) +func (a *App) requireAdmin(c *fiber.Ctx) error { + // Get API key from Authorization header + authHeader := c.Get("Authorization") + if authHeader == "" { + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{ + "error": "missing authorization header", + }) } + + // Extract bearer token + parts := strings.SplitN(authHeader, " ", 2) + if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" { + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{ + "error": "invalid authorization header format, expected: Bearer ", + }) + } + + apiKey := parts[1] + + // Validate API key + key, err := a.authManager.ValidateAPIKey(c.Context(), apiKey) + if err != nil { + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{ + "error": "invalid or expired API key", + }) + } + + // Check if user has admin role or bypass management permission + if key.Role != auth.RoleAdmin && !key.HasPermission(auth.PermissionManageBypasses) { + return c.Status(fiber.StatusForbidden).JSON(fiber.Map{ + "error": "insufficient permissions, admin role required", + }) + } + + // Continue to next handler + return c.Next() } // handleAdminBypasses handles /api/admin/bypasses endpoint -func (a *App) handleAdminBypasses(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") +func (a *App) handleAdminBypasses(c *fiber.Ctx) error { + c.Set("Content-Type", "application/json") + c.Set("Access-Control-Allow-Origin", "*") + c.Set("Access-Control-Allow-Methods", "GET, POST, PATCH, DELETE, OPTIONS") + c.Set("Access-Control-Allow-Headers", "Content-Type, Authorization") - if r.Method == "OPTIONS" { - w.WriteHeader(http.StatusOK) - return + if c.Method() == "OPTIONS" { + return c.SendStatus(fiber.StatusOK) } - switch r.Method { + // Check if there's an ID parameter + id := c.Params("id") + + switch c.Method() { case "GET": - a.requireAdmin(a.handleListBypasses)(w, r) + if id != "" { + return a.handleGetBypass(c) + } + return a.handleListBypasses(c) case "POST": - a.requireAdmin(a.handleCreateBypass)(w, r) - default: - errors.WriteErrorSimple(w, errors.BadRequest("method not allowed")) - } -} - -// handleBypassByID handles /api/admin/bypasses/{id} endpoint -func (a *App) handleBypassByID(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, DELETE, PATCH, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") - - if r.Method == "OPTIONS" { - w.WriteHeader(http.StatusOK) - return - } - - switch r.Method { - case "GET": - a.requireAdmin(a.handleGetBypass)(w, r) - case "DELETE": - a.requireAdmin(a.handleDeleteBypass)(w, r) + return a.handleCreateBypass(c) case "PATCH": - a.requireAdmin(a.handleUpdateBypass)(w, r) + return a.handleUpdateBypass(c) + case "DELETE": + return a.handleDeleteBypass(c) default: - errors.WriteErrorSimple(w, errors.BadRequest("method not allowed")) + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "method not allowed"}) } } // handleListBypasses lists all CVE bypasses -func (a *App) handleListBypasses(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() +func (a *App) handleListBypasses(c *fiber.Ctx) error { + ctx := c.Context() // Parse query parameters - includeExpired := r.URL.Query().Get("include_expired") == "true" - activeOnly := r.URL.Query().Get("active_only") == "true" - bypassType := metadata.BypassType(r.URL.Query().Get("type")) + includeExpired := c.Query("include_expired") == "true" + activeOnly := c.Query("active_only") == "true" + bypassType := metadata.BypassType(c.Query("type")) opts := &metadata.BypassListOptions{ IncludeExpired: includeExpired, @@ -116,11 +99,10 @@ func (a *App) handleListBypasses(w http.ResponseWriter, r *http.Request) { bypasses, err := a.metadata.ListCVEBypasses(ctx, opts) if err != nil { log.Error().Err(err).Msg("Failed to list CVE bypasses") - errors.WriteErrorSimple(w, errors.InternalServer("failed to list bypasses")) - return + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "failed to list bypasses"}) } - errors.WriteJSONSimple(w, http.StatusOK, map[string]interface{}{ + return c.Status(fiber.StatusOK).JSON(fiber.Map{ "bypasses": bypasses, "total": len(bypasses), }) @@ -128,57 +110,43 @@ func (a *App) handleListBypasses(w http.ResponseWriter, r *http.Request) { // CreateBypassRequest represents the request body for creating a bypass type CreateBypassRequest struct { - Type metadata.BypassType `json:"type"` // "cve" or "package" - Target string `json:"target"` // CVE ID or package name - Reason string `json:"reason"` // Why this bypass is needed - CreatedBy string `json:"created_by"` // Admin username - ExpiresInHours int `json:"expires_in_hours"` // How many hours until expiration - AppliesTo string `json:"applies_to,omitempty"` // Optional: limit CVE bypass to specific package - NotifyOnExpiry bool `json:"notify_on_expiry"` // Send notification when expired + Type metadata.BypassType `json:"type"` // "cve" or "package" + Target string `json:"target"` // CVE ID or package name + Reason string `json:"reason"` // Why this bypass is needed + CreatedBy string `json:"created_by"` // Admin username + ExpiresInHours int `json:"expires_in_hours"` // How many hours until expiration + AppliesTo string `json:"applies_to,omitempty"` // Optional: limit CVE bypass to specific package + NotifyOnExpiry bool `json:"notify_on_expiry"` // Send notification when expired } // handleCreateBypass creates a new CVE bypass -func (a *App) handleCreateBypass(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - // Parse request body - body, err := io.ReadAll(r.Body) - if err != nil { - errors.WriteErrorSimple(w, errors.BadRequest("failed to read request body")) - return - } - defer r.Body.Close() +func (a *App) handleCreateBypass(c *fiber.Ctx) error { + ctx := c.Context() var req CreateBypassRequest - if err := json.Unmarshal(body, &req); err != nil { - errors.WriteErrorSimple(w, errors.BadRequest("invalid JSON in request body")) - return + if err := c.BodyParser(&req); err != nil { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid JSON in request body"}) } // Validate request if req.Type != metadata.BypassTypeCVE && req.Type != metadata.BypassTypePackage { - errors.WriteErrorSimple(w, errors.BadRequest("type must be 'cve' or 'package'")) - return + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "type must be 'cve' or 'package'"}) } if req.Target == "" { - errors.WriteErrorSimple(w, errors.BadRequest("target is required")) - return + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "target is required"}) } if req.Reason == "" { - errors.WriteErrorSimple(w, errors.BadRequest("reason is required")) - return + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "reason is required"}) } if req.CreatedBy == "" { - errors.WriteErrorSimple(w, errors.BadRequest("created_by is required")) - return + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "created_by is required"}) } if req.ExpiresInHours <= 0 { - errors.WriteErrorSimple(w, errors.BadRequest("expires_in_hours must be greater than 0")) - return + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "expires_in_hours must be greater than 0"}) } // Create bypass @@ -201,8 +169,7 @@ func (a *App) handleCreateBypass(w http.ResponseWriter, r *http.Request) { // Save to database if err := a.metadata.SaveCVEBypass(ctx, bypass); err != nil { log.Error().Err(err).Msg("Failed to save CVE bypass") - errors.WriteErrorSimple(w, errors.InternalServer("failed to create bypass")) - return + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "failed to create bypass"}) } log.Info(). @@ -213,23 +180,21 @@ func (a *App) handleCreateBypass(w http.ResponseWriter, r *http.Request) { Time("expires_at", bypass.ExpiresAt). Msg("CVE bypass created") - errors.WriteJSONSimple(w, http.StatusCreated, map[string]interface{}{ + return c.Status(fiber.StatusCreated).JSON(fiber.Map{ "bypass": bypass, "message": "Bypass created successfully", }) } // handleGetBypass gets a specific bypass by ID -func (a *App) handleGetBypass(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() +func (a *App) handleGetBypass(c *fiber.Ctx) error { + ctx := c.Context() - // Extract ID from path - path := strings.TrimPrefix(r.URL.Path, "/api/admin/bypasses/") - bypassID := path + // Extract ID from parameter + bypassID := c.Params("id") if bypassID == "" { - errors.WriteErrorSimple(w, errors.BadRequest("bypass ID is required")) - return + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "bypass ID is required"}) } // Get all bypasses and find the one with matching ID @@ -238,20 +203,18 @@ func (a *App) handleGetBypass(w http.ResponseWriter, r *http.Request) { }) if err != nil { log.Error().Err(err).Msg("Failed to list bypasses") - errors.WriteErrorSimple(w, errors.InternalServer("failed to get bypass")) - return + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "failed to get bypass"}) } for _, bypass := range bypasses { if bypass.ID == bypassID { - errors.WriteJSONSimple(w, http.StatusOK, map[string]interface{}{ + return c.Status(fiber.StatusOK).JSON(fiber.Map{ "bypass": bypass, }) - return } } - errors.WriteErrorSimple(w, errors.NotFound("bypass not found")) + return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "bypass not found"}) } // UpdateBypassRequest represents the request body for updating a bypass @@ -262,30 +225,19 @@ type UpdateBypassRequest struct { } // handleUpdateBypass updates a bypass (activate/deactivate or extend expiration) -func (a *App) handleUpdateBypass(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() +func (a *App) handleUpdateBypass(c *fiber.Ctx) error { + ctx := c.Context() - // Extract ID from path - path := strings.TrimPrefix(r.URL.Path, "/api/admin/bypasses/") - bypassID := path + // Extract ID from parameter + bypassID := c.Params("id") if bypassID == "" { - errors.WriteErrorSimple(w, errors.BadRequest("bypass ID is required")) - return + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "bypass ID is required"}) } - // Parse request body - body, err := io.ReadAll(r.Body) - if err != nil { - errors.WriteErrorSimple(w, errors.BadRequest("failed to read request body")) - return - } - defer r.Body.Close() - var req UpdateBypassRequest - if err := json.Unmarshal(body, &req); err != nil { - errors.WriteErrorSimple(w, errors.BadRequest("invalid JSON in request body")) - return + if err := c.BodyParser(&req); err != nil { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid JSON in request body"}) } // Get current bypass @@ -294,8 +246,7 @@ func (a *App) handleUpdateBypass(w http.ResponseWriter, r *http.Request) { }) if err != nil { log.Error().Err(err).Msg("Failed to list bypasses") - errors.WriteErrorSimple(w, errors.InternalServer("failed to get bypass")) - return + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "failed to get bypass"}) } var currentBypass *metadata.CVEBypass @@ -307,8 +258,7 @@ func (a *App) handleUpdateBypass(w http.ResponseWriter, r *http.Request) { } if currentBypass == nil { - errors.WriteErrorSimple(w, errors.NotFound("bypass not found")) - return + return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "bypass not found"}) } // Update fields @@ -327,8 +277,7 @@ func (a *App) handleUpdateBypass(w http.ResponseWriter, r *http.Request) { // Save updated bypass if err := a.metadata.SaveCVEBypass(ctx, currentBypass); err != nil { log.Error().Err(err).Msg("Failed to update bypass") - errors.WriteErrorSimple(w, errors.InternalServer("failed to update bypass")) - return + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "failed to update bypass"}) } log.Info(). @@ -336,43 +285,39 @@ func (a *App) handleUpdateBypass(w http.ResponseWriter, r *http.Request) { Bool("active", currentBypass.Active). Msg("CVE bypass updated") - errors.WriteJSONSimple(w, http.StatusOK, map[string]interface{}{ + return c.Status(fiber.StatusOK).JSON(fiber.Map{ "bypass": currentBypass, "message": "Bypass updated successfully", }) } // handleDeleteBypass deletes a bypass -func (a *App) handleDeleteBypass(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() +func (a *App) handleDeleteBypass(c *fiber.Ctx) error { + ctx := c.Context() - // Extract ID from path - path := strings.TrimPrefix(r.URL.Path, "/api/admin/bypasses/") - bypassID := path + // Extract ID from parameter + bypassID := c.Params("id") if bypassID == "" { - errors.WriteErrorSimple(w, errors.BadRequest("bypass ID is required")) - return + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "bypass ID is required"}) } // Delete bypass if err := a.metadata.DeleteCVEBypass(ctx, bypassID); err != nil { if strings.Contains(err.Error(), "not found") { - errors.WriteErrorSimple(w, errors.NotFound("bypass not found")) - } else { - log.Error().Err(err).Msg("Failed to delete bypass") - errors.WriteErrorSimple(w, errors.InternalServer("failed to delete bypass")) + return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "bypass not found"}) } - return + log.Error().Err(err).Msg("Failed to delete bypass") + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "failed to delete bypass"}) } log.Info(). Str("bypass_id", bypassID). Msg("CVE bypass deleted") - errors.WriteJSONSimple(w, http.StatusOK, map[string]interface{}{ - "deleted": true, + return c.Status(fiber.StatusOK).JSON(fiber.Map{ + "deleted": true, "bypass_id": bypassID, - "message": "Bypass deleted successfully", + "message": "Bypass deleted successfully", }) } diff --git a/pkg/app/handlers_vulnerabilities.go b/pkg/app/handlers_vulnerabilities.go index 108f914..e13e63d 100644 --- a/pkg/app/handlers_vulnerabilities.go +++ b/pkg/app/handlers_vulnerabilities.go @@ -1,40 +1,38 @@ package app import ( - "net/http" "strings" - "github.com/lukaszraczylo/gohoarder/pkg/errors" + "github.com/gofiber/fiber/v2" "github.com/lukaszraczylo/gohoarder/pkg/metadata" "github.com/rs/zerolog/log" ) // handleVulnerabilities handles /api/packages/{registry}/{name}/{version}/vulnerabilities endpoint -func (a *App) handleVulnerabilities(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type") +func (a *App) handleVulnerabilities(c *fiber.Ctx) error { + c.Set("Content-Type", "application/json") + c.Set("Access-Control-Allow-Origin", "*") + c.Set("Access-Control-Allow-Methods", "GET, OPTIONS") + c.Set("Access-Control-Allow-Headers", "Content-Type") - if r.Method == "OPTIONS" { - w.WriteHeader(http.StatusOK) - return + if c.Method() == "OPTIONS" { + return c.SendStatus(fiber.StatusOK) } - if r.Method != "GET" { - errors.WriteErrorSimple(w, errors.BadRequest("method not allowed")) - return + if c.Method() != "GET" { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "method not allowed"}) } - ctx := r.Context() + ctx := c.Context() // Parse path: /api/packages/{registry}/{name}/{version}/vulnerabilities - path := strings.TrimPrefix(r.URL.Path, "/api/packages/") + path := strings.TrimPrefix(c.Path(), "/api/packages/") path = strings.TrimSuffix(path, "/vulnerabilities") parts := strings.Split(path, "/") if len(parts) < 3 { - errors.WriteErrorSimple(w, errors.BadRequest("invalid path format, expected /api/packages/{registry}/{name}/{version}/vulnerabilities")) - return + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + "error": "invalid path format, expected /api/packages/{registry}/{name}/{version}/vulnerabilities", + }) } registry := parts[0] @@ -53,13 +51,12 @@ func (a *App) handleVulnerabilities(w http.ResponseWriter, r *http.Request) { // Check if package exists pkg, pkgErr := a.metadata.GetPackage(ctx, registry, name, version) if pkgErr != nil { - errors.WriteErrorSimple(w, errors.NotFound("package not found")) - return + return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "package not found"}) } // Package exists but not scanned yet - errors.WriteJSONSimple(w, http.StatusOK, map[string]interface{}{ - "package": map[string]string{ + return c.Status(fiber.StatusOK).JSON(fiber.Map{ + "package": fiber.Map{ "registry": registry, "name": name, "version": version, @@ -71,7 +68,6 @@ func (a *App) handleVulnerabilities(w http.ResponseWriter, r *http.Request) { "message": "Package not yet scanned for vulnerabilities", "security_scanned": pkg.SecurityScanned, }) - return } // Get active bypasses to show which vulnerabilities are bypassed @@ -135,8 +131,8 @@ func (a *App) handleVulnerabilities(w http.ResponseWriter, r *http.Request) { } // Build response - response := map[string]interface{}{ - "package": map[string]string{ + response := fiber.Map{ + "package": fiber.Map{ "registry": registry, "name": name, "version": version, @@ -147,7 +143,7 @@ func (a *App) handleVulnerabilities(w http.ResponseWriter, r *http.Request) { "status": scanResult.Status, "vulnerabilities": enrichedVulns, "vulnerability_count": scanResult.VulnerabilityCount, - "severity_counts": map[string]int{ + "severity_counts": fiber.Map{ "critical": severityCounts["CRITICAL"], "high": severityCounts["HIGH"], "moderate": severityCounts["MODERATE"], @@ -156,5 +152,5 @@ func (a *App) handleVulnerabilities(w http.ResponseWriter, r *http.Request) { "bypassed_count": len(scanResult.Vulnerabilities) - (severityCounts["CRITICAL"] + severityCounts["HIGH"] + severityCounts["MODERATE"] + severityCounts["LOW"]), } - errors.WriteJSONSimple(w, http.StatusOK, response) + return c.Status(fiber.StatusOK).JSON(response) } diff --git a/pkg/config/config.go b/pkg/config/config.go index f27633d..e0e4071 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -125,9 +125,14 @@ type VulnerabilityThresholds struct { // ScannersConfig contains individual scanner configurations type ScannersConfig struct { - Trivy TrivyConfig `mapstructure:"trivy" json:"trivy"` - OSV OSVConfig `mapstructure:"osv" json:"osv"` - Static StaticConfig `mapstructure:"static" json:"static"` + Trivy TrivyConfig `mapstructure:"trivy" json:"trivy"` + OSV OSVConfig `mapstructure:"osv" json:"osv"` + Static StaticConfig `mapstructure:"static" json:"static"` + Grype GrypeConfig `mapstructure:"grype" json:"grype"` + Govulncheck GovulncheckConfig `mapstructure:"govulncheck" json:"govulncheck"` + NpmAudit NpmAuditConfig `mapstructure:"npm_audit" json:"npm_audit"` + PipAudit PipAuditConfig `mapstructure:"pip_audit" json:"pip_audit"` + GHSA GHSAConfig `mapstructure:"ghsa" json:"ghsa"` } // TrivyConfig contains Trivy scanner configuration @@ -153,6 +158,37 @@ type StaticConfig struct { AllowedLicenses []string `mapstructure:"allowed_licenses" json:"allowed_licenses"` } +// GrypeConfig contains Grype scanner configuration +type GrypeConfig struct { + Enabled bool `mapstructure:"enabled" json:"enabled"` + Timeout time.Duration `mapstructure:"timeout" json:"timeout"` +} + +// GovulncheckConfig contains govulncheck scanner configuration +type GovulncheckConfig struct { + Enabled bool `mapstructure:"enabled" json:"enabled"` + Timeout time.Duration `mapstructure:"timeout" json:"timeout"` +} + +// NpmAuditConfig contains npm audit scanner configuration +type NpmAuditConfig struct { + Enabled bool `mapstructure:"enabled" json:"enabled"` + Timeout time.Duration `mapstructure:"timeout" json:"timeout"` +} + +// PipAuditConfig contains pip-audit scanner configuration +type PipAuditConfig struct { + Enabled bool `mapstructure:"enabled" json:"enabled"` + Timeout time.Duration `mapstructure:"timeout" json:"timeout"` +} + +// GHSAConfig contains GitHub Advisory Database scanner configuration +type GHSAConfig struct { + Enabled bool `mapstructure:"enabled" json:"enabled"` + Timeout time.Duration `mapstructure:"timeout" json:"timeout"` + Token string `mapstructure:"token" json:"-"` // GitHub token for higher rate limits (don't serialize) +} + // AuthConfig contains authentication configuration type AuthConfig struct { Enabled bool `mapstructure:"enabled" json:"enabled"` @@ -287,6 +323,27 @@ func Default() *Config { CheckChecksums: true, BlockSuspicious: false, }, + Grype: GrypeConfig{ + Enabled: false, + Timeout: 5 * time.Minute, + }, + Govulncheck: GovulncheckConfig{ + Enabled: false, + Timeout: 5 * time.Minute, + }, + NpmAudit: NpmAuditConfig{ + Enabled: false, + Timeout: 2 * time.Minute, + }, + PipAudit: PipAuditConfig{ + Enabled: false, + Timeout: 2 * time.Minute, + }, + GHSA: GHSAConfig{ + Enabled: false, + Timeout: 30 * time.Second, + Token: "", + }, }, }, Auth: AuthConfig{ diff --git a/pkg/metadata/file/file.go b/pkg/metadata/file/file.go index 5f14e43..14d1af6 100644 --- a/pkg/metadata/file/file.go +++ b/pkg/metadata/file/file.go @@ -521,6 +521,24 @@ func (s *Store) CleanupExpiredBypasses(ctx context.Context) (int, error) { return count, nil } +// GetTimeSeriesStats returns time-series download statistics +// File-based store doesn't support time-series statistics +func (s *Store) GetTimeSeriesStats(ctx context.Context, period string, registry string) (*metadata.TimeSeriesStats, error) { + // Return empty time-series data for file-based store + return &metadata.TimeSeriesStats{ + Period: period, + Registry: registry, + DataPoints: []*metadata.TimeSeriesDataPoint{}, + }, nil +} + +// AggregateDownloadData aggregates download data +// File-based store doesn't support aggregation +func (s *Store) AggregateDownloadData(ctx context.Context) error { + // No-op for file-based store + return nil +} + // Close closes the store func (s *Store) Close() error { // Nothing to close for file-based store diff --git a/pkg/metadata/interface.go b/pkg/metadata/interface.go index e2e84cf..c0df62c 100644 --- a/pkg/metadata/interface.go +++ b/pkg/metadata/interface.go @@ -56,6 +56,12 @@ type MetadataStore interface { // Health checks metadata store health Health(ctx context.Context) error + // GetTimeSeriesStats returns time-series download statistics + GetTimeSeriesStats(ctx context.Context, period string, registry string) (*TimeSeriesStats, error) + + // AggregateDownloadData aggregates raw download events and cleans up old data + AggregateDownloadData(ctx context.Context) error + // Close closes the metadata store Close() error } @@ -144,6 +150,19 @@ type Stats struct { LastUpdated time.Time `json:"last_updated"` } +// TimeSeriesDataPoint represents a single data point in time-series +type TimeSeriesDataPoint struct { + Timestamp time.Time `json:"timestamp"` + Value int64 `json:"value"` +} + +// TimeSeriesStats represents time-series download statistics +type TimeSeriesStats struct { + Period string `json:"period"` // 1h, 1day, 7day, 30day + Registry string `json:"registry"` // empty string for all registries + DataPoints []*TimeSeriesDataPoint `json:"data_points"` +} + // CVEBypass represents a temporary bypass for a CVE or package type CVEBypass struct { ID string `json:"id"` // Unique bypass ID diff --git a/pkg/metadata/sqlite/sqlite.go b/pkg/metadata/sqlite/sqlite.go index 132f3ab..390880c 100644 --- a/pkg/metadata/sqlite/sqlite.go +++ b/pkg/metadata/sqlite/sqlite.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "strings" "sync" "time" @@ -89,6 +90,32 @@ CREATE INDEX IF NOT EXISTS idx_cve_bypasses_type ON cve_bypasses(type); CREATE INDEX IF NOT EXISTS idx_cve_bypasses_target ON cve_bypasses(target); CREATE INDEX IF NOT EXISTS idx_cve_bypasses_expires_at ON cve_bypasses(expires_at); CREATE INDEX IF NOT EXISTS idx_cve_bypasses_active ON cve_bypasses(active); + +CREATE TABLE IF NOT EXISTS download_events ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + registry TEXT NOT NULL, + package_name TEXT NOT NULL, + package_version TEXT NOT NULL, + downloaded_at DATETIME NOT NULL, + FOREIGN KEY(registry, package_name, package_version) REFERENCES packages(registry, name, version) +); + +CREATE INDEX IF NOT EXISTS idx_download_events_registry ON download_events(registry); +CREATE INDEX IF NOT EXISTS idx_download_events_downloaded_at ON download_events(downloaded_at); +CREATE INDEX IF NOT EXISTS idx_download_events_package ON download_events(registry, package_name, package_version); + +CREATE TABLE IF NOT EXISTS aggregated_download_stats ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + registry TEXT NOT NULL, + time_bucket DATETIME NOT NULL, + resolution TEXT NOT NULL, + download_count INTEGER NOT NULL, + UNIQUE(registry, time_bucket, resolution) +); + +CREATE INDEX IF NOT EXISTS idx_aggregated_stats_registry ON aggregated_download_stats(registry); +CREATE INDEX IF NOT EXISTS idx_aggregated_stats_time_bucket ON aggregated_download_stats(time_bucket); +CREATE INDEX IF NOT EXISTS idx_aggregated_stats_resolution ON aggregated_download_stats(resolution); ` // New creates a new SQLite metadata store @@ -340,23 +367,47 @@ func (s *SQLiteStore) ListPackages(ctx context.Context, opts *metadata.ListOptio return packages, nil } -// UpdateDownloadCount increments download counter +// UpdateDownloadCount increments download counter and records download event func (s *SQLiteStore) UpdateDownloadCount(ctx context.Context, registry, name, version string) error { s.mu.Lock() defer s.mu.Unlock() - query := ` + now := time.Now() + + // Start transaction + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to start transaction") + } + defer tx.Rollback() + + // Update download count + updateQuery := ` UPDATE packages SET download_count = download_count + 1, last_accessed = ? WHERE registry = ? AND name = ? AND version = ? ` - - _, err := s.db.ExecContext(ctx, query, time.Now(), registry, name, version) + _, err = tx.ExecContext(ctx, updateQuery, now, registry, name, version) if err != nil { return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to update download count") } + // Record download event for time-series statistics + insertQuery := ` + INSERT INTO download_events (registry, package_name, package_version, downloaded_at) + VALUES (?, ?, ?, ?) + ` + _, err = tx.ExecContext(ctx, insertQuery, registry, name, version, now) + if err != nil { + return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to record download event") + } + + // Commit transaction + if err := tx.Commit(); err != nil { + return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to commit transaction") + } + return nil } @@ -372,11 +423,12 @@ func (s *SQLiteStore) GetStats(ctx context.Context, registry string) (*metadata. COALESCE(SUM(download_count), 0) as total_downloads, COALESCE(SUM(CASE WHEN security_scanned = 1 THEN 1 ELSE 0 END), 0) as scanned_packages FROM packages + WHERE version NOT IN ('list', 'latest', 'metadata', 'page') ` args := []interface{}{} if registry != "" { - query += " WHERE registry = ?" + query += " AND registry = ?" args = append(args, registry) } @@ -408,6 +460,257 @@ func (s *SQLiteStore) GetStats(ctx context.Context, registry string) (*metadata. return &stats, nil } +// GetTimeSeriesStats returns time-series download statistics +// Uses different data sources based on period for efficiency: +// - 1h: raw download_events (last hour only) +// - 1day: hourly aggregates +// - 7day, 30day: daily aggregates +func (s *SQLiteStore) GetTimeSeriesStats(ctx context.Context, period string, registry string) (*metadata.TimeSeriesStats, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + var ( + timeFormat string + startTime time.Time + bucketCount int + useRawEvents bool + useResolution string + ) + + now := time.Now() + + // Determine time range, bucket size, and data source based on period + switch period { + case "1h": + startTime = now.Add(-1 * time.Hour) + timeFormat = "%Y-%m-%d %H:%M:00" // 5-minute buckets + bucketCount = 12 // 12 x 5min = 60min + useRawEvents = true // Use raw events for last hour + case "1day": + startTime = now.Add(-24 * time.Hour) + timeFormat = "%Y-%m-%d %H:00:00" // hourly buckets + bucketCount = 24 + useResolution = "hourly" // Use hourly aggregates + case "7day": + startTime = now.Add(-7 * 24 * time.Hour) + timeFormat = "%Y-%m-%d 00:00:00" // daily buckets + bucketCount = 7 + useResolution = "daily" // Use daily aggregates + case "30day": + startTime = now.Add(-30 * 24 * time.Hour) + timeFormat = "%Y-%m-%d 00:00:00" // daily buckets + bucketCount = 30 + useResolution = "daily" // Use daily aggregates + default: + return nil, errors.New(errors.ErrCodeBadRequest, "invalid period, must be one of: 1h, 1day, 7day, 30day") + } + + var query string + var args []interface{} + + if useRawEvents { + // Query raw download_events for 1h period + query = ` + SELECT + strftime(?, downloaded_at) as time_bucket, + COUNT(*) as download_count + FROM download_events + WHERE downloaded_at >= ? + ` + args = []interface{}{timeFormat, startTime} + + if registry != "" { + query += " AND registry = ?" + args = append(args, registry) + } + + query += ` + GROUP BY time_bucket + ORDER BY time_bucket ASC + ` + } else { + // Query aggregated_download_stats for longer periods + query = ` + SELECT + time_bucket, + SUM(download_count) as download_count + FROM aggregated_download_stats + WHERE resolution = ? AND time_bucket >= ? + ` + args = []interface{}{useResolution, startTime} + + if registry != "" { + query += " AND registry = ?" + args = append(args, registry) + } + + query += ` + GROUP BY time_bucket + ORDER BY time_bucket ASC + ` + } + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to query time-series stats") + } + defer rows.Close() + + // Collect data points + dataMap := make(map[string]int64) + for rows.Next() { + var bucket string + var count int64 + if err := rows.Scan(&bucket, &count); err != nil { + return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to scan time-series data") + } + dataMap[bucket] = count + } + + if err := rows.Err(); err != nil { + return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "error iterating time-series data") + } + + // Create complete data points array with zeros for missing buckets + dataPoints := make([]*metadata.TimeSeriesDataPoint, 0, bucketCount) + + // Generate all expected buckets + currentTime := startTime + var increment time.Duration + switch period { + case "1h": + increment = 5 * time.Minute + case "1day": + increment = time.Hour + case "7day", "30day": + increment = 24 * time.Hour + } + + for i := 0; i < bucketCount; i++ { + var bucket string + if useRawEvents { + bucket = currentTime.Format(convertGoTimeFormat(timeFormat)) + } else { + // For aggregated data, time_bucket is already in the right format + bucket = currentTime.Format("2006-01-02 15:04:05") + } + count := dataMap[bucket] + + dataPoints = append(dataPoints, &metadata.TimeSeriesDataPoint{ + Timestamp: currentTime, + Value: count, + }) + + currentTime = currentTime.Add(increment) + } + + return &metadata.TimeSeriesStats{ + Period: period, + Registry: registry, + DataPoints: dataPoints, + }, nil +} + +// convertGoTimeFormat converts SQLite strftime format to Go time format +func convertGoTimeFormat(sqliteFormat string) string { + // SQLite strftime to Go time.Format mapping + format := sqliteFormat + format = strings.ReplaceAll(format, "%Y", "2006") + format = strings.ReplaceAll(format, "%m", "01") + format = strings.ReplaceAll(format, "%d", "02") + format = strings.ReplaceAll(format, "%H", "15") + format = strings.ReplaceAll(format, "%M", "04") + format = strings.ReplaceAll(format, "%S", "05") + return format +} + +// AggregateDownloadData aggregates raw download events into hourly/daily buckets and cleans up old data +// This should be called periodically (e.g., every hour) as a background job +func (s *SQLiteStore) AggregateDownloadData(ctx context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() + + log.Info().Msg("Starting download data aggregation") + + // Start transaction + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to start aggregation transaction") + } + defer tx.Rollback() + + now := time.Now() + oneHourAgo := now.Add(-1 * time.Hour) + oneDayAgo := now.Add(-24 * time.Hour) + + // Step 1: Aggregate raw events older than 1 hour into hourly buckets + // Group by registry and hour, then insert into aggregated_download_stats + hourlyAggQuery := ` + INSERT OR REPLACE INTO aggregated_download_stats (registry, time_bucket, resolution, download_count) + SELECT + registry, + strftime('%Y-%m-%d %H:00:00', downloaded_at) as time_bucket, + 'hourly' as resolution, + COUNT(*) as download_count + FROM download_events + WHERE downloaded_at < ? + GROUP BY registry, time_bucket + ` + _, err = tx.ExecContext(ctx, hourlyAggQuery, oneHourAgo) + if err != nil { + log.Error().Err(err).Msg("Failed to aggregate hourly data") + return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to aggregate hourly download data") + } + + // Step 2: Delete raw events older than 1 hour (they're now aggregated) + deleteRawQuery := `DELETE FROM download_events WHERE downloaded_at < ?` + result, err := tx.ExecContext(ctx, deleteRawQuery, oneHourAgo) + if err != nil { + log.Error().Err(err).Msg("Failed to delete old raw events") + return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to delete old download events") + } + rawDeleted, _ := result.RowsAffected() + + // Step 3: Aggregate hourly stats older than 24 hours into daily buckets + dailyAggQuery := ` + INSERT OR REPLACE INTO aggregated_download_stats (registry, time_bucket, resolution, download_count) + SELECT + registry, + strftime('%Y-%m-%d 00:00:00', time_bucket) as time_bucket, + 'daily' as resolution, + SUM(download_count) as download_count + FROM aggregated_download_stats + WHERE resolution = 'hourly' AND time_bucket < ? + GROUP BY registry, strftime('%Y-%m-%d 00:00:00', time_bucket) + ` + _, err = tx.ExecContext(ctx, dailyAggQuery, oneDayAgo) + if err != nil { + log.Error().Err(err).Msg("Failed to aggregate daily data") + return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to aggregate daily download data") + } + + // Step 4: Delete hourly stats older than 24 hours (they're now aggregated into daily) + deleteHourlyQuery := `DELETE FROM aggregated_download_stats WHERE resolution = 'hourly' AND time_bucket < ?` + result, err = tx.ExecContext(ctx, deleteHourlyQuery, oneDayAgo) + if err != nil { + log.Error().Err(err).Msg("Failed to delete old hourly aggregates") + return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to delete old hourly aggregates") + } + hourlyDeleted, _ := result.RowsAffected() + + // Commit transaction + if err := tx.Commit(); err != nil { + return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to commit aggregation transaction") + } + + log.Info(). + Int64("raw_events_deleted", rawDeleted). + Int64("hourly_aggregates_deleted", hourlyDeleted). + Msg("Download data aggregation completed successfully") + + return nil +} + // SaveScanResult saves security scan result func (s *SQLiteStore) SaveScanResult(ctx context.Context, result *metadata.ScanResult) error { s.mu.Lock() diff --git a/pkg/proxy/goproxy/goproxy.go b/pkg/proxy/goproxy/goproxy.go index b46dec5..16cb4c1 100644 --- a/pkg/proxy/goproxy/goproxy.go +++ b/pkg/proxy/goproxy/goproxy.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/lukaszraczylo/gohoarder/pkg/cache" + "github.com/lukaszraczylo/gohoarder/pkg/errors" "github.com/lukaszraczylo/gohoarder/pkg/network" "github.com/rs/zerolog/log" ) @@ -194,6 +195,14 @@ func (h *Handler) handleZip(ctx context.Context, w http.ResponseWriter, r *http. if err != nil { log.Error().Err(err).Str("url", url).Msg("Failed to fetch module zip") + + // Check if error is a security violation - return 403 Forbidden + if ghErr, ok := err.(*errors.Error); ok && ghErr.Code == errors.ErrCodeSecurityViolation { + http.Error(w, fmt.Sprintf("Package blocked: %s", ghErr.Message), http.StatusForbidden) + return + } + + // All other errors return 502 Bad Gateway (upstream issues) http.Error(w, "Failed to fetch module zip", http.StatusBadGateway) return } diff --git a/pkg/proxy/npm/npm.go b/pkg/proxy/npm/npm.go index 975f9fb..504fa43 100644 --- a/pkg/proxy/npm/npm.go +++ b/pkg/proxy/npm/npm.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/lukaszraczylo/gohoarder/pkg/cache" + "github.com/lukaszraczylo/gohoarder/pkg/errors" "github.com/lukaszraczylo/gohoarder/pkg/network" "github.com/rs/zerolog/log" ) @@ -148,6 +149,14 @@ func (h *Handler) handleTarball(ctx context.Context, w http.ResponseWriter, r *h if err != nil { log.Error().Err(err).Str("url", url).Msg("Failed to fetch package tarball") + + // Check if error is a security violation - return 403 Forbidden + if ghErr, ok := err.(*errors.Error); ok && ghErr.Code == errors.ErrCodeSecurityViolation { + http.Error(w, fmt.Sprintf("Package blocked: %s", ghErr.Message), http.StatusForbidden) + return + } + + // All other errors return 502 Bad Gateway (upstream issues) http.Error(w, "Failed to fetch package tarball", http.StatusBadGateway) return } diff --git a/pkg/proxy/pypi/pypi.go b/pkg/proxy/pypi/pypi.go index 9deb0c2..c7d8b5b 100644 --- a/pkg/proxy/pypi/pypi.go +++ b/pkg/proxy/pypi/pypi.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/lukaszraczylo/gohoarder/pkg/cache" + "github.com/lukaszraczylo/gohoarder/pkg/errors" "github.com/lukaszraczylo/gohoarder/pkg/network" "github.com/rs/zerolog/log" ) @@ -43,6 +44,8 @@ func New(cacheManager *cache.Manager, client *network.Client, config Config) *Ha func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx := r.Context() path := strings.TrimPrefix(r.URL.Path, "/pypi") + // Also trim /simple prefix since upstream already includes it + path = strings.TrimPrefix(path, "/simple") log.Debug().Str("path", path).Str("method", r.Method).Msg("PyPI proxy request") @@ -163,6 +166,14 @@ func (h *Handler) handlePackageFile(ctx context.Context, w http.ResponseWriter, if err != nil { log.Error().Err(err).Str("url", originalURL).Msg("Failed to fetch package file") + + // Check if error is a security violation - return 403 Forbidden + if ghErr, ok := err.(*errors.Error); ok && ghErr.Code == errors.ErrCodeSecurityViolation { + http.Error(w, fmt.Sprintf("Package blocked: %s", ghErr.Message), http.StatusForbidden) + return + } + + // All other errors return 502 Bad Gateway (upstream issues) http.Error(w, "Failed to fetch package file", http.StatusBadGateway) return } diff --git a/pkg/scanner/ghsa/ghsa.go b/pkg/scanner/ghsa/ghsa.go new file mode 100644 index 0000000..7d3ae36 --- /dev/null +++ b/pkg/scanner/ghsa/ghsa.go @@ -0,0 +1,287 @@ +package ghsa + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/lukaszraczylo/gohoarder/pkg/config" + "github.com/lukaszraczylo/gohoarder/pkg/metadata" + "github.com/lukaszraczylo/gohoarder/pkg/uuid" + "github.com/rs/zerolog/log" +) + +// ScannerName is the name of this scanner +const ScannerName = "github-advisory-database" + +// Scanner implements the GitHub Advisory Database vulnerability scanner +type Scanner struct { + config config.GHSAConfig + httpClient *http.Client +} + +// New creates a new GitHub Advisory Database scanner +func New(cfg config.GHSAConfig) *Scanner { + return &Scanner{ + config: cfg, + httpClient: &http.Client{ + Timeout: 30 * time.Second, + }, + } +} + +// Name returns the scanner name +func (s *Scanner) Name() string { + return ScannerName +} + +// Scan scans a package using GitHub Advisory Database API +func (s *Scanner) Scan(ctx context.Context, registry, packageName, version string, filePath string) (*metadata.ScanResult, error) { + log.Info(). + Str("scanner", ScannerName). + Str("package", packageName). + Str("version", version). + Str("registry", registry). + Msg("Starting GitHub Advisory Database scan") + + // Map registry to GitHub ecosystem + ecosystem := s.mapRegistryToEcosystem(registry) + if ecosystem == "" { + return &metadata.ScanResult{ + ID: uuid.New().String(), + Registry: registry, + PackageName: packageName, + PackageVersion: version, + Scanner: ScannerName, + ScannedAt: time.Now(), + Status: metadata.ScanStatusClean, + VulnerabilityCount: 0, + Vulnerabilities: []metadata.Vulnerability{}, + Details: map[string]interface{}{ + "skipped": fmt.Sprintf("GitHub Advisory Database does not support registry: %s", registry), + }, + }, nil + } + + // Query GitHub Advisory Database + advisories, err := s.queryAdvisories(ctx, ecosystem, packageName) + if err != nil { + log.Warn().Err(err).Msg("Failed to query GitHub Advisory Database") + return s.emptyResult(registry, packageName, version), nil + } + + // Filter advisories that affect this version + affectedAdvisories := s.filterAffectedAdvisories(advisories, version) + + // Convert to our format + result := s.convertResult(affectedAdvisories, registry, packageName, version) + + log.Info(). + Str("scanner", ScannerName). + Str("package", packageName). + Int("vulnerabilities", result.VulnerabilityCount). + Msg("GitHub Advisory Database scan completed") + + return result, nil +} + +// Health checks if GitHub API is accessible +func (s *Scanner) Health(ctx context.Context) error { + req, err := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/advisories", nil) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Accept", "application/vnd.github+json") + if s.config.Token != "" { + req.Header.Set("Authorization", "Bearer "+s.config.Token) + } + + resp, err := s.httpClient.Do(req) + if err != nil { + return fmt.Errorf("github advisory database not accessible: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("github api returned status: %d", resp.StatusCode) + } + + return nil +} + +// mapRegistryToEcosystem maps our registry names to GitHub ecosystem names +func (s *Scanner) mapRegistryToEcosystem(registry string) string { + mapping := map[string]string{ + "npm": "npm", + "pypi": "pip", + "go": "go", + "maven": "maven", + "nuget": "nuget", + "cargo": "cargo", + "pub": "pub", + } + return mapping[strings.ToLower(registry)] +} + +// queryAdvisories queries GitHub Advisory Database for a package +func (s *Scanner) queryAdvisories(ctx context.Context, ecosystem, packageName string) ([]GHSAAdvisory, error) { + url := fmt.Sprintf("https://api.github.com/advisories?ecosystem=%s&affects=%s&per_page=100", ecosystem, packageName) + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Accept", "application/vnd.github+json") + if s.config.Token != "" { + req.Header.Set("Authorization", "Bearer "+s.config.Token) + } + + resp, err := s.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to query advisories: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("github api returned status %d: %s", resp.StatusCode, string(body)) + } + + var advisories []GHSAAdvisory + if err := json.NewDecoder(resp.Body).Decode(&advisories); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + return advisories, nil +} + +// filterAffectedAdvisories filters advisories that affect the given version +func (s *Scanner) filterAffectedAdvisories(advisories []GHSAAdvisory, version string) []GHSAAdvisory { + affected := make([]GHSAAdvisory, 0) + + for _, advisory := range advisories { + // Check if this version is affected + // GitHub API already filters by package, but we need to check version ranges + // For now, we'll include all advisories that match the package + // A more sophisticated implementation would parse version ranges + affected = append(affected, advisory) + } + + return affected +} + +// emptyResult returns an empty scan result +func (s *Scanner) emptyResult(registry, packageName, version string) *metadata.ScanResult { + return &metadata.ScanResult{ + ID: uuid.New().String(), + Registry: registry, + PackageName: packageName, + PackageVersion: version, + Scanner: ScannerName, + ScannedAt: time.Now(), + Status: metadata.ScanStatusClean, + VulnerabilityCount: 0, + Vulnerabilities: []metadata.Vulnerability{}, + Details: map[string]interface{}{}, + } +} + +// convertResult converts GitHub Advisory Database results to our ScanResult format +func (s *Scanner) convertResult(advisories []GHSAAdvisory, registry, packageName, version string) *metadata.ScanResult { + vulnerabilities := make([]metadata.Vulnerability, 0) + severityCounts := make(map[string]int) + + for _, advisory := range advisories { + // Normalize severity + normalizedSeverity := metadata.NormalizeSeverity(advisory.Severity) + severityCounts[normalizedSeverity]++ + + // Extract references + refs := make([]string, 0) + if advisory.HTMLURL != "" { + refs = append(refs, advisory.HTMLURL) + } + for _, ref := range advisory.References { + if ref.URL != "" { + refs = append(refs, ref.URL) + } + } + + // Get fixed versions + fixedIn := "" + for _, vuln := range advisory.Vulnerabilities { + if vuln.FirstPatchedVersion != nil && vuln.FirstPatchedVersion.Identifier != "" { + fixedIn = vuln.FirstPatchedVersion.Identifier + break + } + } + + vulnerabilities = append(vulnerabilities, metadata.Vulnerability{ + ID: advisory.GHSAID, + Severity: normalizedSeverity, + Title: advisory.Summary, + Description: advisory.Description, + References: refs, + FixedIn: fixedIn, + }) + } + + status := metadata.ScanStatusClean + if len(vulnerabilities) > 0 { + status = metadata.ScanStatusVulnerable + } + + return &metadata.ScanResult{ + ID: uuid.New().String(), + Registry: registry, + PackageName: packageName, + PackageVersion: version, + Scanner: ScannerName, + ScannedAt: time.Now(), + Status: status, + VulnerabilityCount: len(vulnerabilities), + Vulnerabilities: vulnerabilities, + Details: map[string]interface{}{ + "severity_counts": severityCounts, + }, + } +} + +// GHSAAdvisory represents a GitHub Security Advisory +type GHSAAdvisory struct { + GHSAID string `json:"ghsa_id"` + CVEID string `json:"cve_id"` + Summary string `json:"summary"` + Description string `json:"description"` + Severity string `json:"severity"` + HTMLURL string `json:"html_url"` + References []GHSAReference `json:"references"` + Vulnerabilities []GHSAVulnerability `json:"vulnerabilities"` + PublishedAt string `json:"published_at"` + UpdatedAt string `json:"updated_at"` +} + +type GHSAReference struct { + URL string `json:"url"` +} + +type GHSAVulnerability struct { + Package GHSAPackage `json:"package"` + VulnerableVersions string `json:"vulnerable_version_range"` + FirstPatchedVersion *GHSAPatchVersion `json:"first_patched_version"` +} + +type GHSAPackage struct { + Ecosystem string `json:"ecosystem"` + Name string `json:"name"` +} + +type GHSAPatchVersion struct { + Identifier string `json:"identifier"` +} diff --git a/pkg/scanner/govulncheck/govulncheck.go b/pkg/scanner/govulncheck/govulncheck.go new file mode 100644 index 0000000..f324960 --- /dev/null +++ b/pkg/scanner/govulncheck/govulncheck.go @@ -0,0 +1,194 @@ +package govulncheck + +import ( + "context" + "encoding/json" + "fmt" + "os" + "os/exec" + "strings" + "time" + + "github.com/lukaszraczylo/gohoarder/pkg/config" + "github.com/lukaszraczylo/gohoarder/pkg/metadata" + "github.com/lukaszraczylo/gohoarder/pkg/uuid" + "github.com/rs/zerolog/log" +) + +// ScannerName is the name of this scanner +const ScannerName = "govulncheck" + +// Scanner implements the govulncheck vulnerability scanner for Go modules +type Scanner struct { + config config.GovulncheckConfig +} + +// New creates a new govulncheck scanner +func New(cfg config.GovulncheckConfig) *Scanner { + return &Scanner{ + config: cfg, + } +} + +// Name returns the scanner name +func (s *Scanner) Name() string { + return ScannerName +} + +// Scan scans a Go module using govulncheck +func (s *Scanner) Scan(ctx context.Context, registry, packageName, version string, filePath string) (*metadata.ScanResult, error) { + // Only scan Go packages + if registry != "go" { + return &metadata.ScanResult{ + ID: uuid.New().String(), + Registry: registry, + PackageName: packageName, + PackageVersion: version, + Scanner: ScannerName, + ScannedAt: time.Now(), + Status: metadata.ScanStatusClean, + VulnerabilityCount: 0, + Vulnerabilities: []metadata.Vulnerability{}, + Details: map[string]interface{}{ + "skipped": "govulncheck only supports Go modules", + }, + }, nil + } + + log.Info(). + Str("scanner", ScannerName). + Str("package", packageName). + Str("version", version). + Msg("Starting govulncheck scan") + + // Create a temporary directory for extraction + tmpDir, err := os.MkdirTemp("", "govulncheck-*") + if err != nil { + return nil, fmt.Errorf("failed to create temp dir: %w", err) + } + defer os.RemoveAll(tmpDir) + + // Extract the .zip file + if err := s.extractZip(filePath, tmpDir); err != nil { + return nil, fmt.Errorf("failed to extract zip: %w", err) + } + + // Run govulncheck + cmd := exec.CommandContext(ctx, "govulncheck", "-json", "-mode=binary", tmpDir) + output, err := cmd.CombinedOutput() + + // govulncheck returns non-zero when vulnerabilities are found + // Parse output regardless of error + var vulns []GovulncheckVuln + if len(output) > 0 { + // Parse line-delimited JSON + lines := strings.Split(string(output), "\n") + for _, line := range lines { + if strings.TrimSpace(line) == "" { + continue + } + var entry GovulncheckEntry + if err := json.Unmarshal([]byte(line), &entry); err != nil { + log.Warn().Err(err).Str("line", line).Msg("Failed to parse govulncheck line") + continue + } + if entry.Finding != nil && entry.Finding.OSV != "" { + vulns = append(vulns, GovulncheckVuln{ + OSV: entry.Finding.OSV, + FixedVersion: entry.Finding.FixedVersion, + }) + } + } + } + + // Convert to our format + result := s.convertResult(vulns, registry, packageName, version) + + log.Info(). + Str("scanner", ScannerName). + Str("package", packageName). + Int("vulnerabilities", result.VulnerabilityCount). + Msg("govulncheck scan completed") + + return result, nil +} + +// Health checks if govulncheck is available +func (s *Scanner) Health(ctx context.Context) error { + cmd := exec.CommandContext(ctx, "govulncheck", "-version") + if err := cmd.Run(); err != nil { + return fmt.Errorf("govulncheck not available: %w (install with: go install golang.org/x/vuln/cmd/govulncheck@latest)", err) + } + return nil +} + +// extractZip extracts a zip file to destination +func (s *Scanner) extractZip(zipPath, destDir string) error { + cmd := exec.Command("unzip", "-q", zipPath, "-d", destDir) + return cmd.Run() +} + +// convertResult converts govulncheck findings to our ScanResult format +func (s *Scanner) convertResult(vulns []GovulncheckVuln, registry, packageName, version string) *metadata.ScanResult { + vulnerabilities := make([]metadata.Vulnerability, 0) + severityCounts := make(map[string]int) + seen := make(map[string]bool) + + for _, vuln := range vulns { + // Deduplicate by OSV ID + if seen[vuln.OSV] { + continue + } + seen[vuln.OSV] = true + + // govulncheck doesn't provide severity in output + // Default to HIGH for found vulnerabilities + severity := metadata.NormalizeSeverity("HIGH") + severityCounts[severity]++ + + vulnerabilities = append(vulnerabilities, metadata.Vulnerability{ + ID: vuln.OSV, + Severity: severity, + Title: vuln.OSV, + Description: fmt.Sprintf("Vulnerability %s found by govulncheck", vuln.OSV), + References: []string{fmt.Sprintf("https://pkg.go.dev/vuln/%s", vuln.OSV)}, + FixedIn: vuln.FixedVersion, + }) + } + + status := metadata.ScanStatusClean + if len(vulnerabilities) > 0 { + status = metadata.ScanStatusVulnerable + } + + return &metadata.ScanResult{ + ID: uuid.New().String(), + Registry: registry, + PackageName: packageName, + PackageVersion: version, + Scanner: ScannerName, + ScannedAt: time.Now(), + Status: status, + VulnerabilityCount: len(vulnerabilities), + Vulnerabilities: vulnerabilities, + Details: map[string]interface{}{ + "severity_counts": severityCounts, + "note": "govulncheck provides reachability analysis for Go modules", + }, + } +} + +// GovulncheckEntry represents a single line of govulncheck JSON output +type GovulncheckEntry struct { + Finding *GovulncheckFinding `json:"finding,omitempty"` +} + +type GovulncheckFinding struct { + OSV string `json:"osv"` + FixedVersion string `json:"fixed_version,omitempty"` +} + +type GovulncheckVuln struct { + OSV string + FixedVersion string +} diff --git a/pkg/scanner/grype/grype.go b/pkg/scanner/grype/grype.go new file mode 100644 index 0000000..3f58895 --- /dev/null +++ b/pkg/scanner/grype/grype.go @@ -0,0 +1,193 @@ +package grype + +import ( + "context" + "encoding/json" + "fmt" + "os/exec" + "time" + + "github.com/lukaszraczylo/gohoarder/pkg/config" + "github.com/lukaszraczylo/gohoarder/pkg/metadata" + "github.com/lukaszraczylo/gohoarder/pkg/uuid" + "github.com/rs/zerolog/log" +) + +// ScannerName is the name of this scanner +const ScannerName = "grype" + +// Scanner implements the Grype vulnerability scanner +type Scanner struct { + config config.GrypeConfig +} + +// New creates a new Grype scanner +func New(cfg config.GrypeConfig) *Scanner { + return &Scanner{ + config: cfg, + } +} + +// Name returns the scanner name +func (s *Scanner) Name() string { + return ScannerName +} + +// Scan scans a package using Grype +func (s *Scanner) Scan(ctx context.Context, registry, packageName, version string, filePath string) (*metadata.ScanResult, error) { + log.Info(). + Str("scanner", ScannerName). + Str("package", packageName). + Str("version", version). + Str("file", filePath). + Msg("Starting Grype scan") + + // Run grype scan + cmd := exec.CommandContext(ctx, "grype", filePath, "-o", "json", "-q") + output, err := cmd.CombinedOutput() + if err != nil { + // Grype returns non-zero exit code when vulnerabilities are found + // Only treat it as error if we got no output + if len(output) == 0 { + return nil, fmt.Errorf("grype scan failed: %w (output: %s)", err, string(output)) + } + } + + // Parse Grype JSON output + var grypeResult GrypeResult + if err := json.Unmarshal(output, &grypeResult); err != nil { + return nil, fmt.Errorf("failed to parse grype output: %w", err) + } + + // Convert to our format + result := s.convertGrypeResult(&grypeResult, registry, packageName, version) + + log.Info(). + Str("scanner", ScannerName). + Str("package", packageName). + Int("vulnerabilities", result.VulnerabilityCount). + Msg("Grype scan completed") + + return result, nil +} + +// Health checks if Grype is available +func (s *Scanner) Health(ctx context.Context) error { + cmd := exec.CommandContext(ctx, "grype", "version") + if err := cmd.Run(); err != nil { + return fmt.Errorf("grype not available: %w", err) + } + return nil +} + +// UpdateDatabase updates Grype's vulnerability database +func (s *Scanner) UpdateDatabase(ctx context.Context) error { + log.Info().Str("scanner", ScannerName).Msg("Updating Grype database") + + cmd := exec.CommandContext(ctx, "grype", "db", "update") + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("failed to update grype database: %w (output: %s)", err, string(output)) + } + + log.Info().Str("scanner", ScannerName).Msg("Grype database updated successfully") + return nil +} + +// convertGrypeResult converts Grype output to our ScanResult format +func (s *Scanner) convertGrypeResult(grypeResult *GrypeResult, registry, packageName, version string) *metadata.ScanResult { + vulnerabilities := make([]metadata.Vulnerability, 0) + severityCounts := make(map[string]int) + + // Process each vulnerability match + for _, match := range grypeResult.Matches { + // Normalize severity + normalizedSeverity := metadata.NormalizeSeverity(match.Vulnerability.Severity) + + // Count by severity + severityCounts[normalizedSeverity]++ + + // Extract fixed version + fixedIn := "" + if match.Vulnerability.Fix.State == "fixed" { + for _, version := range match.Vulnerability.Fix.Versions { + if fixedIn == "" { + fixedIn = version + } + } + } + + // Add to vulnerabilities list + vulnerabilities = append(vulnerabilities, metadata.Vulnerability{ + ID: match.Vulnerability.ID, + Severity: normalizedSeverity, + Title: match.Vulnerability.ID, // Grype doesn't have separate title + Description: match.Vulnerability.Description, + References: match.Vulnerability.URLs, + FixedIn: fixedIn, + }) + } + + // Determine overall status + status := metadata.ScanStatusClean + if len(vulnerabilities) > 0 { + status = metadata.ScanStatusVulnerable + } + + return &metadata.ScanResult{ + ID: uuid.New().String(), + Registry: registry, + PackageName: packageName, + PackageVersion: version, + Scanner: ScannerName, + ScannedAt: time.Now(), + Status: status, + VulnerabilityCount: len(vulnerabilities), + Vulnerabilities: vulnerabilities, + Details: map[string]interface{}{ + "severity_counts": severityCounts, + "grype_version": grypeResult.Descriptor.Version, + }, + } +} + +// GrypeResult represents Grype JSON output structure +type GrypeResult struct { + Matches []GrypeMatch `json:"matches"` + Descriptor GrypeDescriptor `json:"descriptor"` + Source GrypeSource `json:"source"` +} + +type GrypeDescriptor struct { + Name string `json:"name"` + Version string `json:"version"` +} + +type GrypeSource struct { + Type string `json:"type"` + Target map[string]interface{} `json:"target"` +} + +type GrypeMatch struct { + Vulnerability GrypeVulnerability `json:"vulnerability"` + Artifact GrypeArtifact `json:"artifact"` +} + +type GrypeVulnerability struct { + ID string `json:"id"` + Severity string `json:"severity"` + Description string `json:"description"` + URLs []string `json:"urls"` + Fix GrypeFix `json:"fix"` +} + +type GrypeFix struct { + State string `json:"state"` + Versions []string `json:"versions"` +} + +type GrypeArtifact struct { + Name string `json:"name"` + Version string `json:"version"` + Type string `json:"type"` +} diff --git a/pkg/scanner/npmaudit/npmaudit.go b/pkg/scanner/npmaudit/npmaudit.go new file mode 100644 index 0000000..220d8a7 --- /dev/null +++ b/pkg/scanner/npmaudit/npmaudit.go @@ -0,0 +1,234 @@ +package npmaudit + +import ( + "context" + "encoding/json" + "fmt" + "os" + "os/exec" + "path/filepath" + "time" + + "github.com/lukaszraczylo/gohoarder/pkg/config" + "github.com/lukaszraczylo/gohoarder/pkg/metadata" + "github.com/lukaszraczylo/gohoarder/pkg/uuid" + "github.com/rs/zerolog/log" +) + +// ScannerName is the name of this scanner +const ScannerName = "npm-audit" + +// Scanner implements the npm audit vulnerability scanner +type Scanner struct { + config config.NpmAuditConfig +} + +// New creates a new npm audit scanner +func New(cfg config.NpmAuditConfig) *Scanner { + return &Scanner{ + config: cfg, + } +} + +// Name returns the scanner name +func (s *Scanner) Name() string { + return ScannerName +} + +// Scan scans an npm package using npm audit +func (s *Scanner) Scan(ctx context.Context, registry, packageName, version string, filePath string) (*metadata.ScanResult, error) { + // Only scan npm packages + if registry != "npm" { + return &metadata.ScanResult{ + ID: uuid.New().String(), + Registry: registry, + PackageName: packageName, + PackageVersion: version, + Scanner: ScannerName, + ScannedAt: time.Now(), + Status: metadata.ScanStatusClean, + VulnerabilityCount: 0, + Vulnerabilities: []metadata.Vulnerability{}, + Details: map[string]interface{}{ + "skipped": "npm-audit only supports npm packages", + }, + }, nil + } + + log.Info(). + Str("scanner", ScannerName). + Str("package", packageName). + Str("version", version). + Msg("Starting npm audit scan") + + // Create a temporary directory + tmpDir, err := os.MkdirTemp("", "npm-audit-*") + if err != nil { + return nil, fmt.Errorf("failed to create temp dir: %w", err) + } + defer os.RemoveAll(tmpDir) + + // Extract the .tgz file + if err := s.extractTgz(filePath, tmpDir); err != nil { + return nil, fmt.Errorf("failed to extract tgz: %w", err) + } + + // Find the package directory (usually "package/") + packageDir := filepath.Join(tmpDir, "package") + if _, err := os.Stat(packageDir); os.IsNotExist(err) { + // Try the tmpDir itself + packageDir = tmpDir + } + + // Run npm audit + cmd := exec.CommandContext(ctx, "npm", "audit", "--json", "--package-lock-only") + cmd.Dir = packageDir + output, _ := cmd.CombinedOutput() // npm audit returns non-zero when vulns found + + // Parse npm audit output + var auditResult NpmAuditResult + if len(output) > 0 { + if err := json.Unmarshal(output, &auditResult); err != nil { + log.Warn().Err(err).Msg("Failed to parse npm audit output") + // Return clean result on parse error + return s.emptyResult(registry, packageName, version), nil + } + } + + // Convert to our format + result := s.convertResult(&auditResult, registry, packageName, version) + + log.Info(). + Str("scanner", ScannerName). + Str("package", packageName). + Int("vulnerabilities", result.VulnerabilityCount). + Msg("npm audit scan completed") + + return result, nil +} + +// Health checks if npm is available +func (s *Scanner) Health(ctx context.Context) error { + cmd := exec.CommandContext(ctx, "npm", "--version") + if err := cmd.Run(); err != nil { + return fmt.Errorf("npm not available: %w", err) + } + return nil +} + +// extractTgz extracts a .tgz file +func (s *Scanner) extractTgz(tgzPath, destDir string) error { + cmd := exec.Command("tar", "-xzf", tgzPath, "-C", destDir) + return cmd.Run() +} + +// emptyResult returns an empty scan result +func (s *Scanner) emptyResult(registry, packageName, version string) *metadata.ScanResult { + return &metadata.ScanResult{ + ID: uuid.New().String(), + Registry: registry, + PackageName: packageName, + PackageVersion: version, + Scanner: ScannerName, + ScannedAt: time.Now(), + Status: metadata.ScanStatusClean, + VulnerabilityCount: 0, + Vulnerabilities: []metadata.Vulnerability{}, + Details: map[string]interface{}{}, + } +} + +// convertResult converts npm audit output to our ScanResult format +func (s *Scanner) convertResult(auditResult *NpmAuditResult, registry, packageName, version string) *metadata.ScanResult { + vulnerabilities := make([]metadata.Vulnerability, 0) + severityCounts := make(map[string]int) + + // Process vulnerabilities from the audit result + for _, vuln := range auditResult.Vulnerabilities { + // Normalize severity + normalizedSeverity := metadata.NormalizeSeverity(vuln.Severity) + severityCounts[normalizedSeverity]++ + + // Get references + refs := make([]string, 0) + if vuln.URL != "" { + refs = append(refs, vuln.URL) + } + for _, ref := range vuln.References { + if ref.URL != "" { + refs = append(refs, ref.URL) + } + } + + // Get fixed version + fixedIn := "" + if vuln.FixAvailable != nil { + fixedIn = fmt.Sprintf("%v", vuln.FixAvailable) + } + + vulnerabilities = append(vulnerabilities, metadata.Vulnerability{ + ID: vuln.Via, + Severity: normalizedSeverity, + Title: vuln.Name, + Description: vuln.Name, + References: refs, + FixedIn: fixedIn, + }) + } + + status := metadata.ScanStatusClean + if len(vulnerabilities) > 0 { + status = metadata.ScanStatusVulnerable + } + + return &metadata.ScanResult{ + ID: uuid.New().String(), + Registry: registry, + PackageName: packageName, + PackageVersion: version, + Scanner: ScannerName, + ScannedAt: time.Now(), + Status: status, + VulnerabilityCount: len(vulnerabilities), + Vulnerabilities: vulnerabilities, + Details: map[string]interface{}{ + "severity_counts": severityCounts, + }, + } +} + +// NpmAuditResult represents npm audit JSON output +type NpmAuditResult struct { + AuditReportVersion int `json:"auditReportVersion"` + Vulnerabilities map[string]NpmVulnerability `json:"vulnerabilities"` + Metadata NpmAuditMetadata `json:"metadata"` +} + +type NpmVulnerability struct { + Name string `json:"name"` + Severity string `json:"severity"` + Via string `json:"via"` + Effects []string `json:"effects"` + Range string `json:"range"` + FixAvailable interface{} `json:"fixAvailable"` + URL string `json:"url"` + References []NpmReference `json:"references"` +} + +type NpmReference struct { + URL string `json:"url"` +} + +type NpmAuditMetadata struct { + Vulnerabilities NpmVulnCounts `json:"vulnerabilities"` + Dependencies int `json:"dependencies"` +} + +type NpmVulnCounts struct { + Info int `json:"info"` + Low int `json:"low"` + Moderate int `json:"moderate"` + High int `json:"high"` + Critical int `json:"critical"` + Total int `json:"total"` +} diff --git a/pkg/scanner/pipaudit/pipaudit.go b/pkg/scanner/pipaudit/pipaudit.go new file mode 100644 index 0000000..4220900 --- /dev/null +++ b/pkg/scanner/pipaudit/pipaudit.go @@ -0,0 +1,209 @@ +package pipaudit + +import ( + "context" + "encoding/json" + "fmt" + "os" + "os/exec" + "path/filepath" + "time" + + "github.com/lukaszraczylo/gohoarder/pkg/config" + "github.com/lukaszraczylo/gohoarder/pkg/metadata" + "github.com/lukaszraczylo/gohoarder/pkg/uuid" + "github.com/rs/zerolog/log" +) + +// ScannerName is the name of this scanner +const ScannerName = "pip-audit" + +// Scanner implements the pip-audit vulnerability scanner +type Scanner struct { + config config.PipAuditConfig +} + +// New creates a new pip-audit scanner +func New(cfg config.PipAuditConfig) *Scanner { + return &Scanner{ + config: cfg, + } +} + +// Name returns the scanner name +func (s *Scanner) Name() string { + return ScannerName +} + +// Scan scans a Python package using pip-audit +func (s *Scanner) Scan(ctx context.Context, registry, packageName, version string, filePath string) (*metadata.ScanResult, error) { + // Only scan PyPI packages + if registry != "pypi" { + return &metadata.ScanResult{ + ID: uuid.New().String(), + Registry: registry, + PackageName: packageName, + PackageVersion: version, + Scanner: ScannerName, + ScannedAt: time.Now(), + Status: metadata.ScanStatusClean, + VulnerabilityCount: 0, + Vulnerabilities: []metadata.Vulnerability{}, + Details: map[string]interface{}{ + "skipped": "pip-audit only supports PyPI packages", + }, + }, nil + } + + log.Info(). + Str("scanner", ScannerName). + Str("package", packageName). + Str("version", version). + Msg("Starting pip-audit scan") + + // Create a temporary directory + tmpDir, err := os.MkdirTemp("", "pip-audit-*") + if err != nil { + return nil, fmt.Errorf("failed to create temp dir: %w", err) + } + defer os.RemoveAll(tmpDir) + + // Copy the wheel/tar.gz file to temp directory + tmpFile := filepath.Join(tmpDir, filepath.Base(filePath)) + if err := s.copyFile(filePath, tmpFile); err != nil { + return nil, fmt.Errorf("failed to copy file: %w", err) + } + + // Run pip-audit on the package file + cmd := exec.CommandContext(ctx, "pip-audit", "-r", tmpFile, "--format", "json") + output, _ := cmd.CombinedOutput() // pip-audit returns non-zero when vulns found + + // Parse pip-audit output + var auditResult PipAuditResult + if len(output) > 0 { + if err := json.Unmarshal(output, &auditResult); err != nil { + log.Warn().Err(err).Msg("Failed to parse pip-audit output") + return s.emptyResult(registry, packageName, version), nil + } + } + + // Convert to our format + result := s.convertResult(&auditResult, registry, packageName, version) + + log.Info(). + Str("scanner", ScannerName). + Str("package", packageName). + Int("vulnerabilities", result.VulnerabilityCount). + Msg("pip-audit scan completed") + + return result, nil +} + +// Health checks if pip-audit is available +func (s *Scanner) Health(ctx context.Context) error { + cmd := exec.CommandContext(ctx, "pip-audit", "--version") + if err := cmd.Run(); err != nil { + return fmt.Errorf("pip-audit not available: %w (install with: pip install pip-audit)", err) + } + return nil +} + +// copyFile copies a file from src to dst +func (s *Scanner) copyFile(src, dst string) error { + input, err := os.ReadFile(src) + if err != nil { + return err + } + return os.WriteFile(dst, input, 0644) +} + +// emptyResult returns an empty scan result +func (s *Scanner) emptyResult(registry, packageName, version string) *metadata.ScanResult { + return &metadata.ScanResult{ + ID: uuid.New().String(), + Registry: registry, + PackageName: packageName, + PackageVersion: version, + Scanner: ScannerName, + ScannedAt: time.Now(), + Status: metadata.ScanStatusClean, + VulnerabilityCount: 0, + Vulnerabilities: []metadata.Vulnerability{}, + Details: map[string]interface{}{}, + } +} + +// convertResult converts pip-audit output to our ScanResult format +func (s *Scanner) convertResult(auditResult *PipAuditResult, registry, packageName, version string) *metadata.ScanResult { + vulnerabilities := make([]metadata.Vulnerability, 0) + severityCounts := make(map[string]int) + + for _, dep := range auditResult.Dependencies { + for _, vuln := range dep.Vulns { + // Map pip-audit severity to our standard + severity := s.mapSeverity(vuln.ID) + normalizedSeverity := metadata.NormalizeSeverity(severity) + severityCounts[normalizedSeverity]++ + + // Get fixed versions + fixedIn := "" + if len(vuln.FixVersions) > 0 { + fixedIn = vuln.FixVersions[0] + } + + vulnerabilities = append(vulnerabilities, metadata.Vulnerability{ + ID: vuln.ID, + Severity: normalizedSeverity, + Title: vuln.ID, + Description: vuln.Description, + References: []string{fmt.Sprintf("https://osv.dev/vulnerability/%s", vuln.ID)}, + FixedIn: fixedIn, + }) + } + } + + status := metadata.ScanStatusClean + if len(vulnerabilities) > 0 { + status = metadata.ScanStatusVulnerable + } + + return &metadata.ScanResult{ + ID: uuid.New().String(), + Registry: registry, + PackageName: packageName, + PackageVersion: version, + Scanner: ScannerName, + ScannedAt: time.Now(), + Status: status, + VulnerabilityCount: len(vulnerabilities), + Vulnerabilities: vulnerabilities, + Details: map[string]interface{}{ + "severity_counts": severityCounts, + }, + } +} + +// mapSeverity maps vulnerability ID patterns to severity levels +func (s *Scanner) mapSeverity(vulnID string) string { + // pip-audit doesn't provide severity directly + // Default to MODERATE for all findings + return "MODERATE" +} + +// PipAuditResult represents pip-audit JSON output +type PipAuditResult struct { + Dependencies []PipDependency `json:"dependencies"` +} + +type PipDependency struct { + Name string `json:"name"` + Version string `json:"version"` + Vulns []PipVuln `json:"vulns"` +} + +type PipVuln struct { + ID string `json:"id"` + Description string `json:"description"` + FixVersions []string `json:"fix_versions"` + Aliases []string `json:"aliases"` +} diff --git a/pkg/scanner/rescanner.go b/pkg/scanner/rescanner.go index 316b677..958696c 100644 --- a/pkg/scanner/rescanner.go +++ b/pkg/scanner/rescanner.go @@ -104,6 +104,11 @@ func (w *RescanWorker) rescanPackages(ctx context.Context) { } if !needsRescan { + log.Debug(). + Str("package", pkg.Name). + Str("version", pkg.Version). + Bool("security_scanned", pkg.SecurityScanned). + Msg("Package does not need rescanning, skipping") skipped++ continue } diff --git a/pkg/scanner/scanner.go b/pkg/scanner/scanner.go index 0c0b994..6c97617 100644 --- a/pkg/scanner/scanner.go +++ b/pkg/scanner/scanner.go @@ -7,7 +7,12 @@ import ( "github.com/lukaszraczylo/gohoarder/pkg/config" "github.com/lukaszraczylo/gohoarder/pkg/metadata" + "github.com/lukaszraczylo/gohoarder/pkg/scanner/ghsa" + "github.com/lukaszraczylo/gohoarder/pkg/scanner/govulncheck" + "github.com/lukaszraczylo/gohoarder/pkg/scanner/grype" + "github.com/lukaszraczylo/gohoarder/pkg/scanner/npmaudit" "github.com/lukaszraczylo/gohoarder/pkg/scanner/osv" + "github.com/lukaszraczylo/gohoarder/pkg/scanner/pipaudit" "github.com/lukaszraczylo/gohoarder/pkg/scanner/trivy" "github.com/rs/zerolog/log" ) @@ -72,6 +77,48 @@ func New(cfg config.SecurityConfig, metadataStore metadata.MetadataStore) (*Mana log.Info().Msg("OSV scanner enabled") } + // Initialize Grype scanner + if cfg.Scanners.Grype.Enabled { + grypeScanner := grype.New(cfg.Scanners.Grype) + manager.RegisterScanner(grypeScanner) + log.Info().Msg("Grype scanner enabled") + + // Update database on startup if configured + if cfg.UpdateDBOnStartup { + if err := grypeScanner.UpdateDatabase(context.Background()); err != nil { + log.Warn().Err(err).Msg("Failed to update Grype database on startup") + } + } + } + + // Initialize govulncheck scanner + if cfg.Scanners.Govulncheck.Enabled { + govulncheckScanner := govulncheck.New(cfg.Scanners.Govulncheck) + manager.RegisterScanner(govulncheckScanner) + log.Info().Msg("govulncheck scanner enabled") + } + + // Initialize npm-audit scanner + if cfg.Scanners.NpmAudit.Enabled { + npmAuditScanner := npmaudit.New(cfg.Scanners.NpmAudit) + manager.RegisterScanner(npmAuditScanner) + log.Info().Msg("npm-audit scanner enabled") + } + + // Initialize pip-audit scanner + if cfg.Scanners.PipAudit.Enabled { + pipAuditScanner := pipaudit.New(cfg.Scanners.PipAudit) + manager.RegisterScanner(pipAuditScanner) + log.Info().Msg("pip-audit scanner enabled") + } + + // Initialize GitHub Advisory Database scanner + if cfg.Scanners.GHSA.Enabled { + ghsaScanner := ghsa.New(cfg.Scanners.GHSA) + manager.RegisterScanner(ghsaScanner) + log.Info().Msg("GitHub Advisory Database scanner enabled") + } + if len(manager.scanners) == 0 { log.Warn().Msg("Security scanning enabled but no scanners configured") } @@ -101,6 +148,15 @@ func (m *Manager) ScanPackage(ctx context.Context, registry, packageName, versio scannerNames := make([]string, 0) for _, scanner := range m.scanners { + // Skip scanners that don't support this registry + if !m.shouldRunScanner(scanner.Name(), registry) { + log.Debug(). + Str("scanner", scanner.Name()). + Str("registry", registry). + Msg("Skipping scanner - not compatible with registry") + continue + } + result, err := scanner.Scan(ctx, registry, packageName, version, filePath) if err != nil { log.Error(). @@ -433,3 +489,27 @@ func (m *Manager) Health(ctx context.Context) error { } return nil } + +// shouldRunScanner determines if a scanner should run for a given registry +// Language-specific scanners only run for their target ecosystems +func (m *Manager) shouldRunScanner(scannerName, registry string) bool { + registry = strings.ToLower(registry) + + // Language-specific scanners - only run for their target registry + switch scannerName { + case "govulncheck": + return registry == "go" + case "npm-audit": + return registry == "npm" + case "pip-audit": + return registry == "pypi" + + // Multi-ecosystem scanners - run for all registries + case "trivy", "osv", "grype", "github-advisory-database": + return true + + // Default: allow scanner to run (for future scanners) + default: + return true + } +} diff --git a/script/test-packages.sh b/script/test-packages.sh new file mode 100755 index 0000000..5410714 --- /dev/null +++ b/script/test-packages.sh @@ -0,0 +1,153 @@ +#!/bin/bash +set -e + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +# Configuration +GOHOARDER_URL="${GOHOARDER_URL:-}" +TEMP_DIR="/tmp/gohoarder-test-$$" + +# Cleanup function +cleanup() { + echo "" + echo "Cleaning up temporary directories..." + rm -rf "$TEMP_DIR" +} +trap cleanup EXIT + +# Auto-detect gohoarder URL if not set +if [ -z "$GOHOARDER_URL" ]; then + # Try to read port from config.yaml + if [ -f "config.yaml" ]; then + PORT=$(grep "^ port:" config.yaml | awk '{print $2}') + if [ -n "$PORT" ]; then + GOHOARDER_URL="http://localhost:$PORT" + fi + fi + + # Fallback to default + if [ -z "$GOHOARDER_URL" ]; then + GOHOARDER_URL="http://localhost:8080" + fi +fi + +echo "=========================================" +echo "Downloading test packages through gohoarder" +echo "GoHoarder URL: $GOHOARDER_URL" +echo "=========================================" +echo "" + +# Check if gohoarder is running +if ! curl -s -f "$GOHOARDER_URL/api/stats" > /dev/null 2>&1; then + echo -e "${RED}ERROR: gohoarder is not running at $GOHOARDER_URL${NC}" + echo "" + echo "Please start gohoarder first with: make run" + echo "" + echo "If gohoarder is running on a different port, set GOHOARDER_URL:" + echo " GOHOARDER_URL=http://localhost:9090 make test-packages" + exit 1 +fi + +echo -e "${GREEN}✓ gohoarder is running${NC}" +echo "" + +# Create temp directories +mkdir -p "$TEMP_DIR/npm" "$TEMP_DIR/pypi" "$TEMP_DIR/go" + +# +# npm packages +# +echo -e "${YELLOW}Testing npm packages...${NC}" + +npm_packages=( + "axios@0.21.1:has vulnerabilities (SSRF, ReDoS)" + "lodash@4.17.15:has vulnerabilities (prototype pollution)" + "express@4.17.1:has vulnerabilities (open redirect)" + "react@18.2.0:clean package" +) + +for pkg_info in "${npm_packages[@]}"; do + IFS=':' read -r pkg desc <<< "$pkg_info" + IFS='@' read -r pkg_name pkg_version <<< "$pkg" + echo -n " • $pkg ($desc)... " + + # Download tarball directly to ensure it goes through proxy + # npm/pnpm may use local cache and bypass the proxy + tarball_filename="${pkg_name##*/}-${pkg_version}.tgz" + tarball_url="$GOHOARDER_URL/npm/$pkg_name/-/$tarball_filename" + + if curl -f -s "$tarball_url" -o "$TEMP_DIR/npm/$tarball_filename" > /dev/null 2>&1; then + echo -e "${GREEN}✓${NC}" + else + echo -e "${RED}✗${NC}" + fi +done + +echo "" + +# +# PyPI packages +# +echo -e "${YELLOW}Testing PyPI packages...${NC}" + +pypi_packages=( + "requests==2.25.0:older version, may have vulnerabilities" + "django==2.2.0:old version with known security issues" + "flask==0.12.0:old version with XSS vulnerabilities" + "certifi==2023.7.22:clean package" +) + +for pkg_info in "${pypi_packages[@]}"; do + IFS=':' read -r pkg desc <<< "$pkg_info" + echo -n " • $pkg ($desc)... " + if pip install --index-url "$GOHOARDER_URL/pypi/simple/" \ + --trusted-host localhost \ + "$pkg" \ + --target "$TEMP_DIR/pypi" \ + --quiet > /dev/null 2>&1; then + echo -e "${GREEN}✓${NC}" + else + echo -e "${RED}✗${NC}" + fi +done + +echo "" + +# +# Go packages +# +echo -e "${YELLOW}Testing Go packages...${NC}" + +cd "$TEMP_DIR/go" +go mod init test > /dev/null 2>&1 + +go_packages=( + "github.com/gin-gonic/gin@v1.7.0:may have vulnerabilities" + "github.com/dgrijalva/jwt-go@v3.2.0:known JWT signing vulnerabilities" + "golang.org/x/crypto@v0.0.0-20200622213623-75b288015ac9:old version" + "github.com/google/uuid@v1.6.0:clean package" +) + +for pkg_info in "${go_packages[@]}"; do + IFS=':' read -r pkg desc <<< "$pkg_info" + echo -n " • $pkg ($desc)... " + if GOPROXY="$GOHOARDER_URL/go,direct" go get "$pkg" > /dev/null 2>&1; then + echo -e "${GREEN}✓${NC}" + else + echo -e "${RED}✗${NC}" + fi +done + +echo "" +echo "=========================================" +echo -e "${GREEN}Test package downloads complete!${NC}" +echo "" +echo "Next steps:" +echo " • Visit $GOHOARDER_URL to view packages" +echo " • Check vulnerability scan results" +echo " • Compare clean vs vulnerable packages" +echo "========================================="