mirror of
https://github.com/lukaszraczylo/gohoarder.git
synced 2026-06-10 23:29:22 +00:00
fixes
This commit is contained in:
@@ -0,0 +1,34 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/cache"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/network"
|
||||
)
|
||||
|
||||
// BaseHandler provides common functionality for all proxy handlers
|
||||
type BaseHandler struct {
|
||||
Cache *cache.Manager
|
||||
Client *network.Client
|
||||
Upstream string
|
||||
Registry string
|
||||
}
|
||||
|
||||
// Config holds common proxy configuration
|
||||
type Config struct {
|
||||
Upstream string // Upstream registry URL (e.g., registry.npmjs.org)
|
||||
}
|
||||
|
||||
// GetRegistry returns the registry type
|
||||
func (h *BaseHandler) GetRegistry() string {
|
||||
return h.Registry
|
||||
}
|
||||
|
||||
// NewBaseHandler creates a new base handler with common fields
|
||||
func NewBaseHandler(cache *cache.Manager, client *network.Client, registry, upstream string) *BaseHandler {
|
||||
return &BaseHandler{
|
||||
Cache: cache,
|
||||
Client: client,
|
||||
Upstream: upstream,
|
||||
Registry: registry,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,385 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/cache"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestNewBaseHandler tests base handler creation
|
||||
func TestNewBaseHandler(t *testing.T) {
|
||||
// Use nil for cache and client since we're only testing structure
|
||||
handler := NewBaseHandler(nil, nil, "npm", "https://registry.npmjs.org")
|
||||
|
||||
require.NotNil(t, handler)
|
||||
assert.Equal(t, "npm", handler.Registry)
|
||||
assert.Equal(t, "https://registry.npmjs.org", handler.Upstream)
|
||||
assert.Nil(t, handler.Cache)
|
||||
assert.Nil(t, handler.Client)
|
||||
}
|
||||
|
||||
// TestGetRegistry tests registry type retrieval
|
||||
func TestGetRegistry(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
registry string
|
||||
}{
|
||||
{"npm registry", "npm"},
|
||||
{"pypi registry", "pypi"},
|
||||
{"go registry", "go"},
|
||||
{"custom registry", "custom"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
handler := &BaseHandler{Registry: tt.registry}
|
||||
assert.Equal(t, tt.registry, handler.GetRegistry())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleUpstreamError tests upstream error handling
|
||||
func TestHandleUpstreamError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
url string
|
||||
context string
|
||||
wantStatus int
|
||||
wantContain string
|
||||
}{
|
||||
// GOOD: Standard error
|
||||
{
|
||||
name: "connection error",
|
||||
err: errors.New("connection refused"),
|
||||
url: "https://registry.npmjs.org/react",
|
||||
context: "package",
|
||||
wantStatus: http.StatusBadGateway,
|
||||
wantContain: "Failed to fetch package",
|
||||
},
|
||||
// WRONG: Timeout error
|
||||
{
|
||||
name: "timeout error",
|
||||
err: context.DeadlineExceeded,
|
||||
url: "https://registry.npmjs.org/lodash",
|
||||
context: "metadata",
|
||||
wantStatus: http.StatusBadGateway,
|
||||
wantContain: "Failed to fetch metadata",
|
||||
},
|
||||
// EDGE: Empty context
|
||||
{
|
||||
name: "empty context",
|
||||
err: errors.New("error"),
|
||||
url: "https://example.com",
|
||||
context: "",
|
||||
wantStatus: http.StatusBadGateway,
|
||||
wantContain: "Failed to fetch",
|
||||
},
|
||||
// EDGE: Long URL
|
||||
{
|
||||
name: "long URL",
|
||||
err: errors.New("error"),
|
||||
url: "https://registry.npmjs.org/@scope/very-long-package-name/versions/1.2.3",
|
||||
context: "package",
|
||||
wantStatus: http.StatusBadGateway,
|
||||
wantContain: "Failed to fetch package",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
HandleUpstreamError(w, tt.err, tt.url, tt.context)
|
||||
|
||||
assert.Equal(t, tt.wantStatus, w.Code)
|
||||
assert.Contains(t, w.Body.String(), tt.wantContain)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckUpstreamStatus tests upstream status validation
|
||||
func TestCheckUpstreamStatus(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body io.ReadCloser
|
||||
wantErr bool
|
||||
errContains string
|
||||
bodyClosed bool
|
||||
}{
|
||||
// GOOD: OK status
|
||||
{
|
||||
name: "200 OK",
|
||||
statusCode: http.StatusOK,
|
||||
body: io.NopCloser(strings.NewReader("success")),
|
||||
wantErr: false,
|
||||
},
|
||||
// WRONG: Not found
|
||||
{
|
||||
name: "404 Not Found",
|
||||
statusCode: http.StatusNotFound,
|
||||
body: io.NopCloser(strings.NewReader("not found")),
|
||||
wantErr: true,
|
||||
errContains: "upstream returned status 404",
|
||||
},
|
||||
// WRONG: Server error
|
||||
{
|
||||
name: "500 Internal Server Error",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
body: io.NopCloser(strings.NewReader("error")),
|
||||
wantErr: true,
|
||||
errContains: "upstream returned status 500",
|
||||
},
|
||||
// BAD: Unauthorized
|
||||
{
|
||||
name: "401 Unauthorized",
|
||||
statusCode: http.StatusUnauthorized,
|
||||
body: io.NopCloser(strings.NewReader("unauthorized")),
|
||||
wantErr: true,
|
||||
errContains: "upstream returned status 401",
|
||||
},
|
||||
// EDGE: Nil body
|
||||
{
|
||||
name: "nil body with error",
|
||||
statusCode: http.StatusNotFound,
|
||||
body: nil,
|
||||
wantErr: true,
|
||||
errContains: "upstream returned status 404",
|
||||
},
|
||||
// EDGE: Redirect status
|
||||
{
|
||||
name: "302 Found",
|
||||
statusCode: http.StatusFound,
|
||||
body: io.NopCloser(strings.NewReader("redirect")),
|
||||
wantErr: true,
|
||||
errContains: "upstream returned status 302",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := CheckUpstreamStatus(tt.statusCode, tt.body)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleInvalidRequest tests invalid request handling
|
||||
func TestHandleInvalidRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
registry string
|
||||
wantStatus int
|
||||
wantContain string
|
||||
}{
|
||||
{
|
||||
name: "npm invalid request",
|
||||
registry: "npm",
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantContain: "Invalid npm request",
|
||||
},
|
||||
{
|
||||
name: "pypi invalid request",
|
||||
registry: "pypi",
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantContain: "Invalid pypi request",
|
||||
},
|
||||
{
|
||||
name: "go invalid request",
|
||||
registry: "go",
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantContain: "Invalid go request",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
HandleInvalidRequest(w, tt.registry)
|
||||
|
||||
assert.Equal(t, tt.wantStatus, w.Code)
|
||||
assert.Contains(t, w.Body.String(), tt.wantContain)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleInternalError tests internal error handling
|
||||
func TestHandleInternalError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
context string
|
||||
wantStatus int
|
||||
wantContain string
|
||||
}{
|
||||
{
|
||||
name: "database error",
|
||||
err: errors.New("database connection failed"),
|
||||
context: "database",
|
||||
wantStatus: http.StatusInternalServerError,
|
||||
wantContain: "Internal error: database",
|
||||
},
|
||||
{
|
||||
name: "cache error",
|
||||
err: errors.New("cache write failed"),
|
||||
context: "cache",
|
||||
wantStatus: http.StatusInternalServerError,
|
||||
wantContain: "Internal error: cache",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
HandleInternalError(w, tt.err, tt.context)
|
||||
|
||||
assert.Equal(t, tt.wantStatus, w.Code)
|
||||
assert.Contains(t, w.Body.String(), tt.wantContain)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Note: FetchFromUpstream tests would require mocking cache.Manager and network.Client
|
||||
// which requires concrete implementations. Integration tests cover this functionality.
|
||||
|
||||
// TestWriteResponse tests HTTP response writing
|
||||
func TestWriteResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data string
|
||||
contentType string
|
||||
wantStatus int
|
||||
wantBody string
|
||||
wantErr bool
|
||||
}{
|
||||
// GOOD: Write tarball
|
||||
{
|
||||
name: "write tarball",
|
||||
data: "package data here",
|
||||
contentType: "application/octet-stream",
|
||||
wantStatus: http.StatusOK,
|
||||
wantBody: "package data here",
|
||||
wantErr: false,
|
||||
},
|
||||
// GOOD: Write JSON
|
||||
{
|
||||
name: "write JSON metadata",
|
||||
data: `{"name":"react","version":"18.2.0"}`,
|
||||
contentType: "application/json",
|
||||
wantStatus: http.StatusOK,
|
||||
wantBody: `{"name":"react","version":"18.2.0"}`,
|
||||
wantErr: false,
|
||||
},
|
||||
// EDGE: Empty data
|
||||
{
|
||||
name: "empty data",
|
||||
data: "",
|
||||
contentType: "text/plain",
|
||||
wantStatus: http.StatusOK,
|
||||
wantBody: "",
|
||||
wantErr: false,
|
||||
},
|
||||
// EDGE: Large data
|
||||
{
|
||||
name: "large data",
|
||||
data: strings.Repeat("x", 100000),
|
||||
contentType: "application/octet-stream",
|
||||
wantStatus: http.StatusOK,
|
||||
wantBody: strings.Repeat("x", 100000),
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
entry := &cache.CacheEntry{
|
||||
Data: io.NopCloser(bytes.NewReader([]byte(tt.data))),
|
||||
}
|
||||
|
||||
err := WriteResponse(w, entry, tt.contentType)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.contentType, w.Header().Get("Content-Type"))
|
||||
assert.Equal(t, tt.wantBody, w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseHandlerFields tests that BaseHandler fields are properly set
|
||||
func TestBaseHandlerFields(t *testing.T) {
|
||||
handler := NewBaseHandler(nil, nil, "npm", "https://registry.npmjs.org")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
field string
|
||||
expected interface{}
|
||||
}{
|
||||
{"registry field", "registry", "npm"},
|
||||
{"upstream field", "upstream", "https://registry.npmjs.org"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
switch tt.field {
|
||||
case "registry":
|
||||
assert.Equal(t, tt.expected, handler.Registry)
|
||||
case "upstream":
|
||||
assert.Equal(t, tt.expected, handler.Upstream)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProxyHandlerInterface tests that BaseHandler can be used as ProxyHandler
|
||||
func TestProxyHandlerInterface(t *testing.T) {
|
||||
handler := NewBaseHandler(nil, nil, "npm", "https://registry.npmjs.org")
|
||||
|
||||
// Verify GetRegistry works
|
||||
registry := handler.GetRegistry()
|
||||
assert.Equal(t, "npm", registry)
|
||||
}
|
||||
|
||||
// TestConcurrentWriteResponse tests that WriteResponse is safe for concurrent use
|
||||
func TestConcurrentWriteResponse(t *testing.T) {
|
||||
const numGoroutines = 10
|
||||
|
||||
errs := make(chan error, numGoroutines)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(n int) {
|
||||
w := httptest.NewRecorder()
|
||||
data := strings.Repeat("x", 1000)
|
||||
entry := &cache.CacheEntry{
|
||||
Data: io.NopCloser(bytes.NewReader([]byte(data))),
|
||||
}
|
||||
|
||||
err := WriteResponse(w, entry, "text/plain")
|
||||
errs <- err
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Collect results
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
err := <-errs
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// HandleUpstreamError logs an error and sends an HTTP 502 Bad Gateway response
|
||||
// This is the common pattern used across all proxy handlers when upstream fetch fails
|
||||
func HandleUpstreamError(w http.ResponseWriter, err error, url, context string) {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("url", url).
|
||||
Str("context", context).
|
||||
Msg("Failed to fetch from upstream")
|
||||
|
||||
http.Error(w, fmt.Sprintf("Failed to fetch %s", context), http.StatusBadGateway)
|
||||
}
|
||||
|
||||
// CheckUpstreamStatus validates HTTP status code from upstream
|
||||
// Returns error if status is not OK, closing body if needed
|
||||
func CheckUpstreamStatus(statusCode int, body io.ReadCloser) error {
|
||||
if statusCode != http.StatusOK {
|
||||
if body != nil {
|
||||
body.Close()
|
||||
}
|
||||
return fmt.Errorf("upstream returned status %d", statusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleInvalidRequest sends a 400 Bad Request response for invalid proxy requests
|
||||
func HandleInvalidRequest(w http.ResponseWriter, registry string) {
|
||||
http.Error(w, fmt.Sprintf("Invalid %s request", registry), http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// HandleInternalError logs an internal error and sends 500 response
|
||||
func HandleInternalError(w http.ResponseWriter, err error, context string) {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("context", context).
|
||||
Msg("Internal error processing request")
|
||||
|
||||
http.Error(w, fmt.Sprintf("Internal error: %s", context), http.StatusInternalServerError)
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/cache"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/network"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// FetchFromUpstream is a common helper to fetch content from upstream with caching
|
||||
// This encapsulates the common pattern of: cache.Get -> network.Get -> error handling
|
||||
func FetchFromUpstream(
|
||||
ctx context.Context,
|
||||
cacheManager *cache.Manager,
|
||||
client *network.Client,
|
||||
registry, name, version, upstreamURL string,
|
||||
) (*cache.CacheEntry, error) {
|
||||
entry, err := cacheManager.Get(ctx, registry, name, version, func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
body, statusCode, err := client.Get(ctx, upstreamURL, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if err := CheckUpstreamStatus(statusCode, body); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return body, upstreamURL, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("url", upstreamURL).
|
||||
Str("registry", registry).
|
||||
Str("name", name).
|
||||
Str("version", version).
|
||||
Msg("Failed to fetch package from upstream")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
// WriteResponse writes the cache entry data to the HTTP response writer
|
||||
// Sets appropriate content type and handles errors
|
||||
func WriteResponse(w http.ResponseWriter, entry *cache.CacheEntry, contentType string) error {
|
||||
defer entry.Data.Close()
|
||||
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
if _, err := io.Copy(w, entry.Data); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to write response")
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ProxyHandler defines the common interface for all registry proxies
|
||||
type ProxyHandler interface {
|
||||
http.Handler // ServeHTTP(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
// GetRegistry returns the registry type (npm, pypi, go)
|
||||
GetRegistry() string
|
||||
|
||||
// Health checks if the proxy can reach its upstream
|
||||
Health(ctx context.Context) error
|
||||
}
|
||||
|
||||
// Stats represents proxy statistics
|
||||
type Stats struct {
|
||||
Registry string
|
||||
TotalRequests int64
|
||||
CacheHits int64
|
||||
CacheMisses int64
|
||||
UpstreamErrors int64
|
||||
AvgResponseTime time.Duration
|
||||
LastUpdated time.Time
|
||||
}
|
||||
@@ -0,0 +1,290 @@
|
||||
package goproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/cache"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/network"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Handler implements the GOPROXY protocol
|
||||
type Handler struct {
|
||||
cache *cache.Manager
|
||||
client *network.Client
|
||||
upstream string
|
||||
sumDBURL string
|
||||
}
|
||||
|
||||
// Config holds Go proxy configuration
|
||||
type Config struct {
|
||||
Upstream string // Upstream Go proxy (e.g., proxy.golang.org)
|
||||
SumDBURL string // Checksum database URL
|
||||
}
|
||||
|
||||
// New creates a new Go proxy handler
|
||||
func New(cacheManager *cache.Manager, client *network.Client, config Config) *Handler {
|
||||
if config.Upstream == "" {
|
||||
config.Upstream = "https://proxy.golang.org"
|
||||
}
|
||||
|
||||
if config.SumDBURL == "" {
|
||||
config.SumDBURL = "https://sum.golang.org"
|
||||
}
|
||||
|
||||
return &Handler{
|
||||
cache: cacheManager,
|
||||
client: client,
|
||||
upstream: config.Upstream,
|
||||
sumDBURL: config.SumDBURL,
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP handles GOPROXY protocol requests
|
||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
// Path is already stripped by http.StripPrefix in app.go
|
||||
path := r.URL.Path
|
||||
|
||||
log.Debug().
|
||||
Str("path", path).
|
||||
Msg("Processing Go proxy request")
|
||||
|
||||
// Parse GOPROXY request
|
||||
// Formats:
|
||||
// /@v/list - list versions
|
||||
// /@v/$version.info - version info
|
||||
// /@v/$version.mod - go.mod file
|
||||
// /@v/$version.zip - module zip
|
||||
// /@latest - latest version
|
||||
|
||||
log.Debug().Str("path", path).Msg("Go proxy request")
|
||||
|
||||
// Route request based on path
|
||||
if strings.HasPrefix(path, "/sumdb/") {
|
||||
h.handleSumDB(ctx, w, r, path)
|
||||
} else if strings.HasSuffix(path, "/@v/list") {
|
||||
h.handleList(ctx, w, r, path)
|
||||
} else if strings.Contains(path, "/@v/") && strings.HasSuffix(path, ".info") {
|
||||
h.handleInfo(ctx, w, r, path)
|
||||
} else if strings.Contains(path, "/@v/") && strings.HasSuffix(path, ".mod") {
|
||||
h.handleMod(ctx, w, r, path)
|
||||
} else if strings.Contains(path, "/@v/") && strings.HasSuffix(path, ".zip") {
|
||||
h.handleZip(ctx, w, r, path)
|
||||
} else if strings.HasSuffix(path, "/@latest") {
|
||||
h.handleLatest(ctx, w, r, path)
|
||||
} else {
|
||||
http.Error(w, "Invalid Go proxy request", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
// handleList handles /@v/list requests
|
||||
func (h *Handler) handleList(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
|
||||
url := h.upstream + path
|
||||
modulePath := h.extractModulePath(path)
|
||||
|
||||
entry, err := h.cache.Get(ctx, "go", modulePath, "list", func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
body, statusCode, err := h.client.Get(ctx, url, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
body.Close()
|
||||
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
|
||||
}
|
||||
return body, url, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", url).Msg("Failed to fetch version list")
|
||||
http.Error(w, "Failed to fetch version list", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer entry.Data.Close()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain; charset=UTF-8")
|
||||
io.Copy(w, entry.Data)
|
||||
}
|
||||
|
||||
// handleInfo handles /@v/$version.info requests
|
||||
func (h *Handler) handleInfo(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
|
||||
url := h.upstream + path
|
||||
modulePath := h.extractModulePath(path)
|
||||
version := h.extractVersion(path, ".info")
|
||||
// Use .info suffix to distinguish from .mod and .zip in cache
|
||||
cacheKey := modulePath + "/@v/" + version + ".info"
|
||||
|
||||
entry, err := h.cache.Get(ctx, "go", cacheKey, version, func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
body, statusCode, err := h.client.Get(ctx, url, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
body.Close()
|
||||
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
|
||||
}
|
||||
return body, url, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", url).Msg("Failed to fetch version info")
|
||||
http.Error(w, "Failed to fetch version info", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer entry.Data.Close()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
||||
io.Copy(w, entry.Data)
|
||||
}
|
||||
|
||||
// handleMod handles /@v/$version.mod requests
|
||||
func (h *Handler) handleMod(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
|
||||
url := h.upstream + path
|
||||
modulePath := h.extractModulePath(path)
|
||||
version := h.extractVersion(path, ".mod")
|
||||
// Use .mod suffix to distinguish from .info and .zip in cache
|
||||
cacheKey := modulePath + "/@v/" + version + ".mod"
|
||||
|
||||
entry, err := h.cache.Get(ctx, "go", cacheKey, version, func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
body, statusCode, err := h.client.Get(ctx, url, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
body.Close()
|
||||
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
|
||||
}
|
||||
return body, url, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", url).Msg("Failed to fetch go.mod")
|
||||
http.Error(w, "Failed to fetch go.mod", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer entry.Data.Close()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain; charset=UTF-8")
|
||||
io.Copy(w, entry.Data)
|
||||
}
|
||||
|
||||
// handleZip handles /@v/$version.zip requests
|
||||
func (h *Handler) handleZip(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
|
||||
url := h.upstream + path
|
||||
modulePath := h.extractModulePath(path)
|
||||
version := h.extractVersion(path, ".zip")
|
||||
// Use .zip suffix to distinguish from .info and .mod in cache
|
||||
cacheKey := modulePath + "/@v/" + version + ".zip"
|
||||
|
||||
entry, err := h.cache.Get(ctx, "go", cacheKey, version, func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
body, statusCode, err := h.client.Get(ctx, url, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
body.Close()
|
||||
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
|
||||
}
|
||||
return body, url, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", url).Msg("Failed to fetch module zip")
|
||||
http.Error(w, "Failed to fetch module zip", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer entry.Data.Close()
|
||||
|
||||
w.Header().Set("Content-Type", "application/zip")
|
||||
io.Copy(w, entry.Data)
|
||||
}
|
||||
|
||||
// handleLatest handles /@latest requests
|
||||
func (h *Handler) handleLatest(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
|
||||
url := h.upstream + path
|
||||
modulePath := h.extractModulePath(path)
|
||||
|
||||
entry, err := h.cache.Get(ctx, "go", modulePath, "latest", func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
body, statusCode, err := h.client.Get(ctx, url, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
body.Close()
|
||||
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
|
||||
}
|
||||
return body, url, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", url).Msg("Failed to fetch latest version")
|
||||
http.Error(w, "Failed to fetch latest version", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer entry.Data.Close()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
||||
io.Copy(w, entry.Data)
|
||||
}
|
||||
|
||||
// handleSumDB handles sumdb requests (checksum database)
|
||||
func (h *Handler) handleSumDB(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
|
||||
// path format: /sumdb/sum.golang.org/...
|
||||
// Remove /sumdb/ prefix and proxy to sumdb URL
|
||||
sumdbPath := strings.TrimPrefix(path, "/sumdb/sum.golang.org")
|
||||
url := h.sumDBURL + sumdbPath
|
||||
|
||||
log.Debug().Str("url", url).Msg("Proxying sumdb request")
|
||||
|
||||
// Sumdb requests should not be cached, proxy directly
|
||||
body, statusCode, err := h.client.Get(ctx, url, nil)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", url).Msg("Failed to fetch from sumdb")
|
||||
http.Error(w, "Failed to fetch from sumdb", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer body.Close()
|
||||
|
||||
if statusCode != http.StatusOK {
|
||||
log.Error().Int("status", statusCode).Str("url", url).Msg("Sumdb returned non-OK status")
|
||||
http.Error(w, "Sumdb error", statusCode)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain; charset=UTF-8")
|
||||
io.Copy(w, body)
|
||||
}
|
||||
|
||||
// extractVersion extracts version from path
|
||||
func (h *Handler) extractVersion(path, suffix string) string {
|
||||
// path format: /module/path/@v/v1.2.3.suffix
|
||||
parts := strings.Split(path, "/@v/")
|
||||
if len(parts) != 2 {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSuffix(parts[1], suffix)
|
||||
}
|
||||
|
||||
// extractModulePath extracts the clean module path from a GOPROXY path
|
||||
// Examples:
|
||||
//
|
||||
// /github.com/avast/retry-go/v4/@v/v4.6.1.zip -> github.com/avast/retry-go/v4
|
||||
// /golang.org/x/net/@v/v0.40.0.mod -> golang.org/x/net
|
||||
// /github.com/user/repo/@v/list -> github.com/user/repo
|
||||
func (h *Handler) extractModulePath(path string) string {
|
||||
// Remove leading slash
|
||||
path = strings.TrimPrefix(path, "/")
|
||||
|
||||
// Split on /@v/ to get the module path
|
||||
parts := strings.Split(path, "/@v/")
|
||||
if len(parts) > 0 {
|
||||
return parts[0]
|
||||
}
|
||||
|
||||
// Fallback: remove /@latest suffix if present
|
||||
return strings.TrimSuffix(path, "/@latest")
|
||||
}
|
||||
@@ -0,0 +1,294 @@
|
||||
package npm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/cache"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/network"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Handler implements the NPM registry protocol
|
||||
type Handler struct {
|
||||
cache *cache.Manager
|
||||
client *network.Client
|
||||
upstream string
|
||||
}
|
||||
|
||||
// Config holds NPM proxy configuration
|
||||
type Config struct {
|
||||
Upstream string // Upstream NPM registry (e.g., registry.npmjs.org)
|
||||
}
|
||||
|
||||
// New creates a new NPM proxy handler
|
||||
func New(cacheManager *cache.Manager, client *network.Client, config Config) *Handler {
|
||||
if config.Upstream == "" {
|
||||
config.Upstream = "https://registry.npmjs.org"
|
||||
}
|
||||
|
||||
return &Handler{
|
||||
cache: cacheManager,
|
||||
client: client,
|
||||
upstream: config.Upstream,
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP handles NPM registry requests
|
||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
path := strings.TrimPrefix(r.URL.Path, "/npm")
|
||||
|
||||
log.Debug().Str("path", path).Str("method", r.Method).Msg("NPM proxy request")
|
||||
|
||||
// Handle different NPM request types
|
||||
// Check for tarballs FIRST before special endpoints (tarballs also contain "/-/")
|
||||
if isTarballRequest(path) {
|
||||
// Package tarball: /@scope/package/-/package-version.tgz
|
||||
h.handleTarball(ctx, w, r, path)
|
||||
} else if strings.Contains(path, "/-/") {
|
||||
// Special NPM endpoints (e.g., /-/ping, /-/user/token)
|
||||
h.handleSpecial(ctx, w, r, path)
|
||||
} else if isPackageMetadata(path) {
|
||||
// Package metadata: /@scope/package or /package
|
||||
h.handleMetadata(ctx, w, r, path)
|
||||
} else {
|
||||
http.Error(w, "Invalid NPM request", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
// handleMetadata handles package metadata requests
|
||||
func (h *Handler) handleMetadata(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
|
||||
url := h.upstream + path
|
||||
packageName := extractPackageName(path)
|
||||
|
||||
entry, err := h.cache.Get(ctx, "npm", packageName, "metadata", func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
body, statusCode, err := h.client.Get(ctx, url, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
body.Close()
|
||||
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
|
||||
}
|
||||
return body, url, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", url).Msg("Failed to fetch package metadata")
|
||||
http.Error(w, "Failed to fetch package metadata", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer entry.Data.Close()
|
||||
|
||||
// Read metadata into memory for URL rewriting
|
||||
var buf bytes.Buffer
|
||||
if _, err := io.Copy(&buf, entry.Data); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to read metadata")
|
||||
http.Error(w, "Failed to read metadata", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse JSON metadata
|
||||
var metadata map[string]interface{}
|
||||
if err := json.Unmarshal(buf.Bytes(), &metadata); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to parse metadata JSON")
|
||||
http.Error(w, "Failed to parse metadata", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Rewrite tarball URLs to point to our proxy
|
||||
proxyBaseURL := getProxyBaseURL(r)
|
||||
rewriteMetadataURLs(metadata, h.upstream, proxyBaseURL)
|
||||
|
||||
// Serialize modified metadata
|
||||
modifiedJSON, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to serialize modified metadata")
|
||||
http.Error(w, "Failed to serialize metadata", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
||||
w.Write(modifiedJSON)
|
||||
}
|
||||
|
||||
// handleTarball handles package tarball requests
|
||||
func (h *Handler) handleTarball(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
|
||||
packageName, version := extractTarballInfo(path)
|
||||
|
||||
// Construct proper upstream URL with /-/ format
|
||||
// Format: https://registry.npmjs.org/package/-/package-version.tgz
|
||||
tarballFilename := strings.ReplaceAll(packageName, "/", "-") + "-" + version + ".tgz"
|
||||
url := fmt.Sprintf("%s/%s/-/%s", h.upstream, packageName, tarballFilename)
|
||||
|
||||
log.Debug().
|
||||
Str("path", path).
|
||||
Str("package", packageName).
|
||||
Str("version", version).
|
||||
Str("upstream_url", url).
|
||||
Msg("Handling tarball request")
|
||||
|
||||
entry, err := h.cache.Get(ctx, "npm", packageName, version, func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
body, statusCode, err := h.client.Get(ctx, url, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
body.Close()
|
||||
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
|
||||
}
|
||||
return body, url, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", url).Msg("Failed to fetch package tarball")
|
||||
http.Error(w, "Failed to fetch package tarball", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer entry.Data.Close()
|
||||
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
io.Copy(w, entry.Data)
|
||||
}
|
||||
|
||||
// handleSpecial handles special NPM endpoints
|
||||
func (h *Handler) handleSpecial(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
|
||||
url := h.upstream + path
|
||||
|
||||
// Don't cache special endpoints, proxy directly
|
||||
body, statusCode, err := h.client.Get(ctx, url, nil)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", url).Msg("Failed to fetch special endpoint")
|
||||
http.Error(w, "Failed to fetch from upstream", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer body.Close()
|
||||
|
||||
w.WriteHeader(statusCode)
|
||||
io.Copy(w, body)
|
||||
}
|
||||
|
||||
// isTarballRequest checks if the request is for a tarball
|
||||
func isTarballRequest(path string) bool {
|
||||
return strings.HasSuffix(path, ".tgz") || strings.HasSuffix(path, ".tar.gz")
|
||||
}
|
||||
|
||||
// isPackageMetadata checks if the request is for package metadata
|
||||
func isPackageMetadata(path string) bool {
|
||||
// Package metadata doesn't have file extensions
|
||||
return !isTarballRequest(path) && !strings.Contains(path, "/-/")
|
||||
}
|
||||
|
||||
// extractPackageName extracts package name from path
|
||||
func extractPackageName(path string) string {
|
||||
// Remove leading slash
|
||||
path = strings.TrimPrefix(path, "/")
|
||||
|
||||
// Handle scoped packages (@scope/package)
|
||||
if strings.HasPrefix(path, "@") {
|
||||
parts := strings.Split(path, "/")
|
||||
if len(parts) >= 2 {
|
||||
return parts[0] + "/" + parts[1]
|
||||
}
|
||||
}
|
||||
|
||||
// Regular package
|
||||
parts := strings.Split(path, "/")
|
||||
if len(parts) > 0 {
|
||||
return parts[0]
|
||||
}
|
||||
|
||||
return path
|
||||
}
|
||||
|
||||
// extractTarballInfo extracts package name and version from tarball path
|
||||
func extractTarballInfo(path string) (string, string) {
|
||||
// Format: /@scope/package/-/package-version.tgz
|
||||
// or: /package/-/package-version.tgz
|
||||
// Also handle: /package/package-version.tgz (fallback)
|
||||
|
||||
// Try standard format with /-/
|
||||
parts := strings.Split(path, "/-/")
|
||||
if len(parts) == 2 {
|
||||
packageName := extractPackageName(parts[0])
|
||||
tarballName := parts[1]
|
||||
tarballName = strings.TrimSuffix(tarballName, ".tgz")
|
||||
tarballName = strings.TrimSuffix(tarballName, ".tar.gz")
|
||||
|
||||
// Remove package name prefix to get version
|
||||
prefix := strings.ReplaceAll(packageName, "/", "-") + "-"
|
||||
version := strings.TrimPrefix(tarballName, prefix)
|
||||
|
||||
return packageName, version
|
||||
}
|
||||
|
||||
// Fallback: parse path without /-/
|
||||
// Format: /package/package-version.tgz or /@scope/package/package-version.tgz
|
||||
path = strings.TrimPrefix(path, "/")
|
||||
pathParts := strings.Split(path, "/")
|
||||
|
||||
if len(pathParts) < 2 {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
var packageName, tarballName string
|
||||
|
||||
// Handle scoped packages
|
||||
if strings.HasPrefix(pathParts[0], "@") && len(pathParts) >= 3 {
|
||||
packageName = pathParts[0] + "/" + pathParts[1]
|
||||
tarballName = pathParts[len(pathParts)-1]
|
||||
} else {
|
||||
packageName = pathParts[0]
|
||||
tarballName = pathParts[len(pathParts)-1]
|
||||
}
|
||||
|
||||
tarballName = strings.TrimSuffix(tarballName, ".tgz")
|
||||
tarballName = strings.TrimSuffix(tarballName, ".tar.gz")
|
||||
|
||||
// Remove package name prefix to get version
|
||||
prefix := strings.ReplaceAll(packageName, "/", "-") + "-"
|
||||
version := strings.TrimPrefix(tarballName, prefix)
|
||||
|
||||
return packageName, version
|
||||
}
|
||||
|
||||
// getProxyBaseURL constructs the proxy base URL from the request
|
||||
func getProxyBaseURL(r *http.Request) string {
|
||||
scheme := "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
host := r.Host
|
||||
return fmt.Sprintf("%s://%s/npm", scheme, host)
|
||||
}
|
||||
|
||||
// rewriteMetadataURLs recursively rewrites upstream URLs to proxy URLs in metadata
|
||||
func rewriteMetadataURLs(data interface{}, upstream, proxyBaseURL string) {
|
||||
switch v := data.(type) {
|
||||
case map[string]interface{}:
|
||||
for key, value := range v {
|
||||
if key == "tarball" || key == "dist" {
|
||||
// Rewrite tarball URL
|
||||
if strVal, ok := value.(string); ok {
|
||||
v[key] = strings.Replace(strVal, upstream, proxyBaseURL, 1)
|
||||
} else if distMap, ok := value.(map[string]interface{}); ok {
|
||||
// Handle dist object with tarball field
|
||||
rewriteMetadataURLs(distMap, upstream, proxyBaseURL)
|
||||
}
|
||||
} else {
|
||||
// Recursively process nested objects
|
||||
rewriteMetadataURLs(value, upstream, proxyBaseURL)
|
||||
}
|
||||
}
|
||||
case []interface{}:
|
||||
for _, item := range v {
|
||||
rewriteMetadataURLs(item, upstream, proxyBaseURL)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,307 @@
|
||||
package pypi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/cache"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/network"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Handler implements the PyPI Simple API (PEP 503)
|
||||
type Handler struct {
|
||||
cache *cache.Manager
|
||||
client *network.Client
|
||||
upstream string
|
||||
}
|
||||
|
||||
// Config holds PyPI proxy configuration
|
||||
type Config struct {
|
||||
Upstream string // Upstream PyPI index (e.g., pypi.org/simple)
|
||||
}
|
||||
|
||||
// New creates a new PyPI proxy handler
|
||||
func New(cacheManager *cache.Manager, client *network.Client, config Config) *Handler {
|
||||
if config.Upstream == "" {
|
||||
config.Upstream = "https://pypi.org/simple"
|
||||
}
|
||||
|
||||
return &Handler{
|
||||
cache: cacheManager,
|
||||
client: client,
|
||||
upstream: config.Upstream,
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP handles PyPI Simple API requests
|
||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
path := strings.TrimPrefix(r.URL.Path, "/pypi")
|
||||
|
||||
log.Debug().Str("path", path).Str("method", r.Method).Msg("PyPI proxy request")
|
||||
|
||||
// PEP 503 Simple API endpoints:
|
||||
// / - index page
|
||||
// /{package}/ - package page with links to files
|
||||
|
||||
if path == "/" || path == "" {
|
||||
// Index page
|
||||
h.handleIndex(ctx, w, r)
|
||||
} else if isPackagePage(path) {
|
||||
// Package page
|
||||
h.handlePackagePage(ctx, w, r, path)
|
||||
} else if isPackageFile(path) {
|
||||
// Package file download (wheel or sdist)
|
||||
h.handlePackageFile(ctx, w, r, path)
|
||||
} else {
|
||||
http.Error(w, "Invalid PyPI request", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
// handleIndex handles the index page request
|
||||
func (h *Handler) handleIndex(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
url := h.upstream + "/"
|
||||
|
||||
entry, err := h.cache.Get(ctx, "pypi", "index", "latest", func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
body, statusCode, err := h.client.Get(ctx, url, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
body.Close()
|
||||
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
|
||||
}
|
||||
return body, url, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", url).Msg("Failed to fetch PyPI index")
|
||||
http.Error(w, "Failed to fetch PyPI index", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer entry.Data.Close()
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=UTF-8")
|
||||
io.Copy(w, entry.Data)
|
||||
}
|
||||
|
||||
// handlePackagePage handles package page requests
|
||||
func (h *Handler) handlePackagePage(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
|
||||
url := h.upstream + path
|
||||
packageName := extractPackageName(path)
|
||||
|
||||
entry, err := h.cache.Get(ctx, "pypi", packageName, "page", func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
body, statusCode, err := h.client.Get(ctx, url, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
body.Close()
|
||||
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
|
||||
}
|
||||
return body, url, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", url).Msg("Failed to fetch package page")
|
||||
http.Error(w, "Failed to fetch package page", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer entry.Data.Close()
|
||||
|
||||
// Read page into memory for URL rewriting
|
||||
var buf bytes.Buffer
|
||||
if _, err := io.Copy(&buf, entry.Data); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to read package page")
|
||||
http.Error(w, "Failed to read package page", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Rewrite package file URLs to point to our proxy
|
||||
proxyBaseURL := getProxyBaseURL(r)
|
||||
modifiedHTML := rewritePackagePageURLs(buf.String(), packageName, proxyBaseURL)
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=UTF-8")
|
||||
w.Write([]byte(modifiedHTML))
|
||||
}
|
||||
|
||||
// handlePackageFile handles package file download requests
|
||||
func (h *Handler) handlePackageFile(ctx context.Context, w http.ResponseWriter, r *http.Request, path string) {
|
||||
packageName, version := extractPackageFileInfo(path)
|
||||
|
||||
// Check if we have the original URL from the rewritten package page
|
||||
originalURL := r.URL.Query().Get("original_url")
|
||||
|
||||
// If no original URL provided, fall back to constructing from upstream
|
||||
// (this handles direct file requests not from rewritten package pages)
|
||||
if originalURL == "" {
|
||||
originalURL = h.upstream + path
|
||||
} else {
|
||||
// Make the URL absolute if it's relative
|
||||
if !strings.HasPrefix(originalURL, "http://") && !strings.HasPrefix(originalURL, "https://") {
|
||||
originalURL = "https://pypi.org" + originalURL
|
||||
}
|
||||
}
|
||||
|
||||
entry, err := h.cache.Get(ctx, "pypi", packageName, version, func(ctx context.Context) (io.ReadCloser, string, error) {
|
||||
body, statusCode, err := h.client.Get(ctx, originalURL, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
body.Close()
|
||||
return nil, "", fmt.Errorf("upstream returned status %d", statusCode)
|
||||
}
|
||||
return body, originalURL, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("url", originalURL).Msg("Failed to fetch package file")
|
||||
http.Error(w, "Failed to fetch package file", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer entry.Data.Close()
|
||||
|
||||
// Determine content type based on file extension
|
||||
contentType := "application/octet-stream"
|
||||
if strings.HasSuffix(path, ".whl") {
|
||||
contentType = "application/zip"
|
||||
} else if strings.HasSuffix(path, ".tar.gz") {
|
||||
contentType = "application/x-gzip"
|
||||
} else if strings.HasSuffix(path, ".metadata") {
|
||||
contentType = "text/plain; charset=UTF-8"
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
io.Copy(w, entry.Data)
|
||||
}
|
||||
|
||||
// isPackagePage checks if the request is for a package page
|
||||
func isPackagePage(path string) bool {
|
||||
// Package pages end with /
|
||||
return strings.HasSuffix(path, "/")
|
||||
}
|
||||
|
||||
// isPackageFile checks if the request is for a package file
|
||||
func isPackageFile(path string) bool {
|
||||
// Package files (not including .metadata files which need special handling)
|
||||
return strings.HasSuffix(path, ".whl") ||
|
||||
strings.HasSuffix(path, ".tar.gz") ||
|
||||
strings.HasSuffix(path, ".zip") ||
|
||||
strings.HasSuffix(path, ".egg")
|
||||
}
|
||||
|
||||
// extractPackageName extracts package name from path
|
||||
func extractPackageName(path string) string {
|
||||
// Remove leading and trailing slashes
|
||||
path = strings.Trim(path, "/")
|
||||
|
||||
// Remove /simple/ prefix if present
|
||||
path = strings.TrimPrefix(path, "simple/")
|
||||
|
||||
// For package pages: /package-name/
|
||||
// For files: /package-name/package-name-version.whl
|
||||
parts := strings.Split(path, "/")
|
||||
if len(parts) > 0 {
|
||||
return parts[0]
|
||||
}
|
||||
|
||||
return path
|
||||
}
|
||||
|
||||
// extractPackageFileInfo extracts package name and version from file path
|
||||
func extractPackageFileInfo(path string) (string, string) {
|
||||
// Format: /package-name/package-name-version.whl
|
||||
// or: /package-name/package-name-version.tar.gz
|
||||
|
||||
packageName := extractPackageName(path)
|
||||
|
||||
// Extract filename
|
||||
parts := strings.Split(path, "/")
|
||||
if len(parts) < 2 {
|
||||
return packageName, ""
|
||||
}
|
||||
|
||||
filename := parts[len(parts)-1]
|
||||
|
||||
// Remove extension
|
||||
filename = strings.TrimSuffix(filename, ".whl")
|
||||
filename = strings.TrimSuffix(filename, ".tar.gz")
|
||||
filename = strings.TrimSuffix(filename, ".zip")
|
||||
filename = strings.TrimSuffix(filename, ".egg")
|
||||
|
||||
// Extract version
|
||||
// Filename format: package-name-version or package_name-version
|
||||
// Version typically starts after last dash before build tags
|
||||
versionParts := strings.Split(filename, "-")
|
||||
if len(versionParts) >= 2 {
|
||||
// Simple heuristic: version is the part that starts with a digit
|
||||
for i := 1; i < len(versionParts); i++ {
|
||||
if len(versionParts[i]) > 0 && versionParts[i][0] >= '0' && versionParts[i][0] <= '9' {
|
||||
return packageName, versionParts[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return packageName, filename
|
||||
}
|
||||
|
||||
// getProxyBaseURL constructs the proxy base URL from the request
|
||||
func getProxyBaseURL(r *http.Request) string {
|
||||
scheme := "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
host := r.Host
|
||||
return fmt.Sprintf("%s://%s/pypi", scheme, host)
|
||||
}
|
||||
|
||||
// rewritePackagePageURLs rewrites package file URLs in HTML to point to proxy
|
||||
func rewritePackagePageURLs(html, packageName, proxyBaseURL string) string {
|
||||
// PyPI Simple API uses href attributes in anchor tags
|
||||
// We need to rewrite URLs pointing to files.pythonhosted.org or pypi.org
|
||||
// We preserve the original URL as a query parameter so we can fetch from the correct CDN
|
||||
|
||||
// Regex pattern to match href URLs pointing to package files
|
||||
// Matches: href="https://files.pythonhosted.org/packages/.../filename.whl"
|
||||
// Also matches: href="../../packages/.../filename.whl"
|
||||
pattern := regexp.MustCompile(`href="([^"]*?(\.whl|\.tar\.gz|\.zip|\.egg)[^"]*?)"`)
|
||||
|
||||
result := pattern.ReplaceAllStringFunc(html, func(match string) string {
|
||||
// Extract the full URL and filename
|
||||
urlPattern := regexp.MustCompile(`href="([^"]+)"`)
|
||||
urlMatch := urlPattern.FindStringSubmatch(match)
|
||||
if len(urlMatch) < 2 {
|
||||
return match
|
||||
}
|
||||
|
||||
originalURL := urlMatch[1]
|
||||
|
||||
// Extract just the filename
|
||||
filenamePattern := regexp.MustCompile(`([^/]+\.(whl|tar\.gz|zip|egg))`)
|
||||
filenameMatch := filenamePattern.FindString(originalURL)
|
||||
|
||||
if filenameMatch != "" {
|
||||
// Rewrite to proxy URL format: /pypi/package-name/filename?original_url=...
|
||||
// This preserves the original CDN URL so we can fetch from the correct location
|
||||
baseURL := strings.TrimSuffix(proxyBaseURL, "/simple")
|
||||
|
||||
// URL encode the original URL
|
||||
encodedURL := strings.ReplaceAll(originalURL, "&", "%26")
|
||||
encodedURL = strings.ReplaceAll(encodedURL, "=", "%3D")
|
||||
|
||||
newURL := fmt.Sprintf(`href="%s/%s/%s?original_url=%s"`, baseURL, packageName, filenameMatch, encodedURL)
|
||||
return newURL
|
||||
}
|
||||
|
||||
return match
|
||||
})
|
||||
|
||||
return result
|
||||
}
|
||||
Reference in New Issue
Block a user