mirror of
https://github.com/lukaszraczylo/gohoarder.git
synced 2026-06-13 02:36:48 +00:00
fixes
This commit is contained in:
@@ -0,0 +1,415 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/errors"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metrics"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/storage"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// FilesystemStorage implements storage.StorageBackend for local filesystem
|
||||
type FilesystemStorage struct {
|
||||
basePath string
|
||||
quota int64
|
||||
mu sync.RWMutex
|
||||
used int64
|
||||
}
|
||||
|
||||
// New creates a new filesystem storage backend
|
||||
func New(basePath string, quota int64) (*FilesystemStorage, error) {
|
||||
// Create base directory if it doesn't exist
|
||||
if err := os.MkdirAll(basePath, 0755); err != nil {
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create base directory")
|
||||
}
|
||||
|
||||
fs := &FilesystemStorage{
|
||||
basePath: basePath,
|
||||
quota: quota,
|
||||
}
|
||||
|
||||
// Calculate initial usage
|
||||
if err := fs.calculateUsage(); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to calculate initial storage usage")
|
||||
}
|
||||
|
||||
return fs, nil
|
||||
}
|
||||
|
||||
// Get retrieves a file
|
||||
func (fs *FilesystemStorage) Get(ctx context.Context, key string) (io.ReadCloser, error) {
|
||||
// Check context
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
path := fs.keyToPath(key)
|
||||
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
metrics.RecordStorageOperation("filesystem", "get", "not_found")
|
||||
return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key))
|
||||
}
|
||||
metrics.RecordStorageOperation("filesystem", "get", "error")
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to open file")
|
||||
}
|
||||
|
||||
metrics.RecordStorageOperation("filesystem", "get", "success")
|
||||
return file, nil
|
||||
}
|
||||
|
||||
// Put stores a file atomically
|
||||
func (fs *FilesystemStorage) Put(ctx context.Context, key string, data io.Reader, opts *storage.PutOptions) error {
|
||||
// Check context
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
path := fs.keyToPath(key)
|
||||
dir := filepath.Dir(path)
|
||||
|
||||
// Create directory
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
metrics.RecordStorageOperation("filesystem", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create directory")
|
||||
}
|
||||
|
||||
// Create temp file for atomic write
|
||||
tempPath := path + ".tmp"
|
||||
tempFile, err := os.Create(tempPath)
|
||||
if err != nil {
|
||||
metrics.RecordStorageOperation("filesystem", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create temp file")
|
||||
}
|
||||
|
||||
// Calculate checksums while writing
|
||||
// NOTE: MD5 is used for integrity verification (checksums), not cryptographic security
|
||||
md5Hash := md5.New()
|
||||
sha256Hash := sha256.New()
|
||||
multiWriter := io.MultiWriter(tempFile, md5Hash, sha256Hash)
|
||||
|
||||
written, err := io.Copy(multiWriter, data)
|
||||
if err != nil {
|
||||
tempFile.Close()
|
||||
os.Remove(tempPath)
|
||||
metrics.RecordStorageOperation("filesystem", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to write data")
|
||||
}
|
||||
|
||||
if err := tempFile.Close(); err != nil {
|
||||
os.Remove(tempPath)
|
||||
metrics.RecordStorageOperation("filesystem", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to close temp file")
|
||||
}
|
||||
|
||||
// Check quota
|
||||
fs.mu.Lock()
|
||||
if fs.quota > 0 && fs.used+written > fs.quota {
|
||||
fs.mu.Unlock()
|
||||
os.Remove(tempPath)
|
||||
metrics.RecordStorageOperation("filesystem", "put", "quota_exceeded")
|
||||
return errors.QuotaExceeded(fs.quota)
|
||||
}
|
||||
fs.used += written
|
||||
fs.mu.Unlock()
|
||||
|
||||
// Verify checksums if provided
|
||||
if opts != nil {
|
||||
md5Sum := hex.EncodeToString(md5Hash.Sum(nil))
|
||||
sha256Sum := hex.EncodeToString(sha256Hash.Sum(nil))
|
||||
|
||||
if opts.ChecksumMD5 != "" && opts.ChecksumMD5 != md5Sum {
|
||||
os.Remove(tempPath)
|
||||
metrics.RecordStorageOperation("filesystem", "put", "checksum_error")
|
||||
return errors.New(errors.ErrCodeChecksumMismatch, "MD5 checksum mismatch")
|
||||
}
|
||||
|
||||
if opts.ChecksumSHA256 != "" && opts.ChecksumSHA256 != sha256Sum {
|
||||
os.Remove(tempPath)
|
||||
metrics.RecordStorageOperation("filesystem", "put", "checksum_error")
|
||||
return errors.New(errors.ErrCodeChecksumMismatch, "SHA256 checksum mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
// Atomic rename
|
||||
if err := os.Rename(tempPath, path); err != nil {
|
||||
os.Remove(tempPath)
|
||||
fs.mu.Lock()
|
||||
fs.used -= written
|
||||
currentUsed := fs.used
|
||||
fs.mu.Unlock()
|
||||
metrics.RecordStorageOperation("filesystem", "put", "error")
|
||||
metrics.UpdateCacheSize("filesystem", currentUsed)
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to rename temp file")
|
||||
}
|
||||
|
||||
fs.mu.RLock()
|
||||
currentUsed := fs.used
|
||||
fs.mu.RUnlock()
|
||||
|
||||
metrics.RecordStorageOperation("filesystem", "put", "success")
|
||||
metrics.UpdateCacheSize("filesystem", currentUsed)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a file
|
||||
func (fs *FilesystemStorage) Delete(ctx context.Context, key string) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
path := fs.keyToPath(key)
|
||||
|
||||
// Get size before deletion
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
metrics.RecordStorageOperation("filesystem", "delete", "not_found")
|
||||
return errors.NotFound(fmt.Sprintf("file not found: %s", key))
|
||||
}
|
||||
metrics.RecordStorageOperation("filesystem", "delete", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to stat file")
|
||||
}
|
||||
|
||||
size := info.Size()
|
||||
|
||||
if err := os.Remove(path); err != nil {
|
||||
metrics.RecordStorageOperation("filesystem", "delete", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to delete file")
|
||||
}
|
||||
|
||||
fs.mu.Lock()
|
||||
fs.used -= size
|
||||
currentUsed := fs.used
|
||||
fs.mu.Unlock()
|
||||
|
||||
metrics.RecordStorageOperation("filesystem", "delete", "success")
|
||||
metrics.UpdateCacheSize("filesystem", currentUsed)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists checks if a file exists
|
||||
func (fs *FilesystemStorage) Exists(ctx context.Context, key string) (bool, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
path := fs.keyToPath(key)
|
||||
_, err := os.Stat(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to check existence")
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// List lists files with prefix
|
||||
func (fs *FilesystemStorage) List(ctx context.Context, prefix string, opts *storage.ListOptions) ([]storage.StorageObject, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
searchPath := fs.keyToPath(prefix)
|
||||
var objects []storage.StorageObject
|
||||
|
||||
err := filepath.Walk(searchPath, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return nil // Skip errors
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert path back to key
|
||||
relPath, _ := filepath.Rel(fs.basePath, path)
|
||||
key := filepath.ToSlash(relPath)
|
||||
|
||||
objects = append(objects, storage.StorageObject{
|
||||
Key: key,
|
||||
Size: info.Size(),
|
||||
Modified: info.ModTime(),
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to list files")
|
||||
}
|
||||
|
||||
// Apply pagination if requested
|
||||
if opts != nil {
|
||||
start := opts.Offset
|
||||
end := len(objects)
|
||||
if opts.MaxResults > 0 && start+opts.MaxResults < end {
|
||||
end = start + opts.MaxResults
|
||||
}
|
||||
if start < len(objects) {
|
||||
objects = objects[start:end]
|
||||
}
|
||||
}
|
||||
|
||||
return objects, nil
|
||||
}
|
||||
|
||||
// Stat gets file metadata
|
||||
func (fs *FilesystemStorage) Stat(ctx context.Context, key string) (*storage.StorageInfo, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
path := fs.keyToPath(key)
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key))
|
||||
}
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to stat file")
|
||||
}
|
||||
|
||||
return &storage.StorageInfo{
|
||||
Key: key,
|
||||
Size: info.Size(),
|
||||
Modified: info.ModTime(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetQuota returns quota information
|
||||
func (fs *FilesystemStorage) GetQuota(ctx context.Context) (*storage.QuotaInfo, error) {
|
||||
fs.mu.RLock()
|
||||
used := fs.used
|
||||
fs.mu.RUnlock()
|
||||
|
||||
available := fs.quota - used
|
||||
if available < 0 {
|
||||
available = 0
|
||||
}
|
||||
|
||||
return &storage.QuotaInfo{
|
||||
Used: used,
|
||||
Available: available,
|
||||
Limit: fs.quota,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Health checks filesystem health
|
||||
func (fs *FilesystemStorage) Health(ctx context.Context) error {
|
||||
// Check if base path is accessible
|
||||
if _, err := os.Stat(fs.basePath); err != nil {
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "base path not accessible")
|
||||
}
|
||||
|
||||
// Try to create a temp file (sanitize path to prevent traversal)
|
||||
tempPath := filepath.Clean(filepath.Join(fs.basePath, ".health_check"))
|
||||
f, err := os.Create(tempPath)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "cannot write to storage")
|
||||
}
|
||||
f.Close()
|
||||
os.Remove(tempPath)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the storage backend
|
||||
func (fs *FilesystemStorage) Close() error {
|
||||
// Nothing to close for filesystem
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLocalPath returns the local filesystem path for a storage key
|
||||
// This implements storage.LocalPathProvider interface
|
||||
func (fs *FilesystemStorage) GetLocalPath(ctx context.Context, key string) (string, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
path := fs.keyToPath(key)
|
||||
|
||||
// Verify file exists
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return "", errors.NotFound(fmt.Sprintf("file not found: %s", key))
|
||||
}
|
||||
return "", errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to stat file")
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// keyToPath converts a storage key to filesystem path
|
||||
func (fs *FilesystemStorage) keyToPath(key string) string {
|
||||
// Sanitize key to prevent path traversal
|
||||
key = filepath.Clean(key)
|
||||
|
||||
// Remove any leading slashes or dots
|
||||
key = strings.TrimPrefix(key, "/")
|
||||
|
||||
// Keep removing ../ until there are no more
|
||||
for strings.HasPrefix(key, "../") || strings.HasPrefix(key, "..\\") {
|
||||
key = strings.TrimPrefix(key, "../")
|
||||
key = strings.TrimPrefix(key, "..\\")
|
||||
}
|
||||
|
||||
// Final clean and ensure it's within base path
|
||||
key = filepath.Clean(key)
|
||||
if key == ".." || strings.HasPrefix(key, "../") || strings.HasPrefix(key, "..\\") {
|
||||
key = ""
|
||||
}
|
||||
|
||||
return filepath.Join(fs.basePath, key)
|
||||
}
|
||||
|
||||
// calculateUsage calculates current storage usage
|
||||
func (fs *FilesystemStorage) calculateUsage() error {
|
||||
var total int64
|
||||
|
||||
err := filepath.Walk(fs.basePath, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return nil // Skip errors
|
||||
}
|
||||
if !info.IsDir() {
|
||||
total += info.Size()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fs.mu.Lock()
|
||||
fs.used = total
|
||||
fs.mu.Unlock()
|
||||
|
||||
metrics.UpdateCacheSize("filesystem", total)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,757 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/storage"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type FilesystemStorageTestSuite struct {
|
||||
suite.Suite
|
||||
tempDir string
|
||||
fs *FilesystemStorage
|
||||
}
|
||||
|
||||
func (s *FilesystemStorageTestSuite) SetupTest() {
|
||||
var err error
|
||||
s.tempDir, err = os.MkdirTemp("", "gohoarder-test-*")
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.fs, err = New(s.tempDir, 1024*1024) // 1MB quota
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *FilesystemStorageTestSuite) TearDownTest() {
|
||||
if s.fs != nil {
|
||||
s.fs.Close()
|
||||
}
|
||||
if s.tempDir != "" {
|
||||
os.RemoveAll(s.tempDir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilesystemStorageTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(FilesystemStorageTestSuite))
|
||||
}
|
||||
|
||||
// Test Put operation
|
||||
func (s *FilesystemStorageTestSuite) TestPut() {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
data string
|
||||
opts *storage.PutOptions
|
||||
expectError bool
|
||||
errorCheck func(error) bool
|
||||
}{
|
||||
{
|
||||
name: "successful put",
|
||||
key: "test/file.txt",
|
||||
data: "hello world",
|
||||
opts: nil,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "put with valid MD5 checksum",
|
||||
key: "test/checksummed.txt",
|
||||
data: "test data",
|
||||
opts: &storage.PutOptions{ChecksumMD5: "eb733a00c0c9d336e65691a37ab54293"},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "put with invalid MD5 checksum",
|
||||
key: "test/bad-checksum.txt",
|
||||
data: "test data",
|
||||
opts: &storage.PutOptions{ChecksumMD5: "invalid"},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "put with nested path",
|
||||
key: "deep/nested/path/file.txt",
|
||||
data: "nested content",
|
||||
opts: nil,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "put with path traversal attempt",
|
||||
key: "../../../etc/passwd",
|
||||
data: "malicious",
|
||||
opts: nil,
|
||||
expectError: false, // Should be sanitized, not error
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
ctx := context.Background()
|
||||
reader := strings.NewReader(tt.data)
|
||||
|
||||
err := s.fs.Put(ctx, tt.key, reader, tt.opts)
|
||||
|
||||
if tt.expectError {
|
||||
s.Error(err)
|
||||
} else {
|
||||
s.NoError(err)
|
||||
// Verify file exists
|
||||
exists, err := s.fs.Exists(ctx, tt.key)
|
||||
s.NoError(err)
|
||||
s.True(exists)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test Get operation
|
||||
func (s *FilesystemStorageTestSuite) TestGet() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Setup: Put a test file
|
||||
testData := "test content for retrieval"
|
||||
err := s.fs.Put(ctx, "test/get.txt", strings.NewReader(testData), nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
expectError bool
|
||||
expectData string
|
||||
}{
|
||||
{
|
||||
name: "get existing file",
|
||||
key: "test/get.txt",
|
||||
expectError: false,
|
||||
expectData: testData,
|
||||
},
|
||||
{
|
||||
name: "get non-existent file",
|
||||
key: "does/not/exist.txt",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
reader, err := s.fs.Get(ctx, tt.key)
|
||||
|
||||
if tt.expectError {
|
||||
s.Error(err)
|
||||
s.Nil(reader)
|
||||
} else {
|
||||
s.NoError(err)
|
||||
s.NotNil(reader)
|
||||
defer reader.Close()
|
||||
|
||||
data, err := io.ReadAll(reader)
|
||||
s.NoError(err)
|
||||
s.Equal(tt.expectData, string(data))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test Delete operation
|
||||
func (s *FilesystemStorageTestSuite) TestDelete() {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupKey string
|
||||
deleteKey string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "delete existing file",
|
||||
setupKey: "test/delete-me.txt",
|
||||
deleteKey: "test/delete-me.txt",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "delete non-existent file",
|
||||
setupKey: "",
|
||||
deleteKey: "does/not/exist.txt",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
// Setup
|
||||
if tt.setupKey != "" {
|
||||
err := s.fs.Put(ctx, tt.setupKey, strings.NewReader("to be deleted"), nil)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
// Test delete
|
||||
err := s.fs.Delete(ctx, tt.deleteKey)
|
||||
|
||||
if tt.expectError {
|
||||
s.Error(err)
|
||||
} else {
|
||||
s.NoError(err)
|
||||
// Verify file no longer exists
|
||||
exists, err := s.fs.Exists(ctx, tt.deleteKey)
|
||||
s.NoError(err)
|
||||
s.False(exists)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test Exists operation
|
||||
func (s *FilesystemStorageTestSuite) TestExists() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Setup: Put a test file
|
||||
err := s.fs.Put(ctx, "test/exists.txt", strings.NewReader("content"), nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
exists bool
|
||||
}{
|
||||
{
|
||||
name: "existing file",
|
||||
key: "test/exists.txt",
|
||||
exists: true,
|
||||
},
|
||||
{
|
||||
name: "non-existent file",
|
||||
key: "test/does-not-exist.txt",
|
||||
exists: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
exists, err := s.fs.Exists(ctx, tt.key)
|
||||
s.NoError(err)
|
||||
s.Equal(tt.exists, exists)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test List operation
|
||||
func (s *FilesystemStorageTestSuite) TestList() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Setup: Create multiple files
|
||||
files := []string{
|
||||
"packages/npm/react/17.0.1/package.json",
|
||||
"packages/npm/react/17.0.2/package.json",
|
||||
"packages/npm/vue/3.0.0/package.json",
|
||||
"packages/pypi/django/3.2.0/wheel.whl",
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
err := s.fs.Put(ctx, file, strings.NewReader("content"), nil)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
prefix string
|
||||
opts *storage.ListOptions
|
||||
expectedCount int
|
||||
expectedKeys []string
|
||||
}{
|
||||
{
|
||||
name: "list all npm packages",
|
||||
prefix: "packages/npm",
|
||||
opts: nil,
|
||||
expectedCount: 3,
|
||||
},
|
||||
{
|
||||
name: "list react packages",
|
||||
prefix: "packages/npm/react",
|
||||
opts: nil,
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "list with pagination",
|
||||
prefix: "packages/npm",
|
||||
opts: &storage.ListOptions{MaxResults: 2, Offset: 0},
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "list with offset",
|
||||
prefix: "packages/npm",
|
||||
opts: &storage.ListOptions{MaxResults: 2, Offset: 1},
|
||||
expectedCount: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
objects, err := s.fs.List(ctx, tt.prefix, tt.opts)
|
||||
s.NoError(err)
|
||||
s.Equal(tt.expectedCount, len(objects))
|
||||
|
||||
// Verify objects have required fields
|
||||
for _, obj := range objects {
|
||||
s.NotEmpty(obj.Key)
|
||||
s.Greater(obj.Size, int64(0))
|
||||
s.False(obj.Modified.IsZero())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test Stat operation
|
||||
func (s *FilesystemStorageTestSuite) TestStat() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Setup: Put a test file
|
||||
testData := "stat test content"
|
||||
testKey := "test/stat.txt"
|
||||
err := s.fs.Put(ctx, testKey, strings.NewReader(testData), nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "stat existing file",
|
||||
key: testKey,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "stat non-existent file",
|
||||
key: "does/not/exist.txt",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
info, err := s.fs.Stat(ctx, tt.key)
|
||||
|
||||
if tt.expectError {
|
||||
s.Error(err)
|
||||
s.Nil(info)
|
||||
} else {
|
||||
s.NoError(err)
|
||||
s.NotNil(info)
|
||||
s.Equal(tt.key, info.Key)
|
||||
s.Equal(int64(len(testData)), info.Size)
|
||||
s.False(info.Modified.IsZero())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test Quota enforcement
|
||||
func (s *FilesystemStorageTestSuite) TestQuotaEnforcement() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a new filesystem with small quota (100 bytes)
|
||||
smallQuotaDir, err := os.MkdirTemp("", "gohoarder-quota-*")
|
||||
s.Require().NoError(err)
|
||||
defer os.RemoveAll(smallQuotaDir)
|
||||
|
||||
smallFs, err := New(smallQuotaDir, 100)
|
||||
s.Require().NoError(err)
|
||||
defer smallFs.Close()
|
||||
|
||||
// First write should succeed
|
||||
err = smallFs.Put(ctx, "file1.txt", strings.NewReader("small content"), nil)
|
||||
s.NoError(err)
|
||||
|
||||
// Large write should fail due to quota
|
||||
largeData := strings.Repeat("x", 200)
|
||||
err = smallFs.Put(ctx, "large.txt", strings.NewReader(largeData), nil)
|
||||
s.Error(err)
|
||||
|
||||
// Verify quota info
|
||||
quotaInfo, err := smallFs.GetQuota(ctx)
|
||||
s.NoError(err)
|
||||
s.Equal(int64(100), quotaInfo.Limit)
|
||||
s.Greater(quotaInfo.Used, int64(0))
|
||||
s.LessOrEqual(quotaInfo.Used, quotaInfo.Limit)
|
||||
}
|
||||
|
||||
// Test GetQuota operation
|
||||
func (s *FilesystemStorageTestSuite) TestGetQuota() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Put some files
|
||||
err := s.fs.Put(ctx, "file1.txt", strings.NewReader("content1"), nil)
|
||||
s.Require().NoError(err)
|
||||
err = s.fs.Put(ctx, "file2.txt", strings.NewReader("content2"), nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
quotaInfo, err := s.fs.GetQuota(ctx)
|
||||
s.NoError(err)
|
||||
s.NotNil(quotaInfo)
|
||||
s.Equal(int64(1024*1024), quotaInfo.Limit)
|
||||
s.Greater(quotaInfo.Used, int64(0))
|
||||
s.Greater(quotaInfo.Available, int64(0))
|
||||
s.Equal(quotaInfo.Limit, quotaInfo.Used+quotaInfo.Available)
|
||||
}
|
||||
|
||||
// Test Health check
|
||||
func (s *FilesystemStorageTestSuite) TestHealth() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Healthy filesystem
|
||||
err := s.fs.Health(ctx)
|
||||
s.NoError(err)
|
||||
|
||||
// Unhealthy filesystem (removed directory)
|
||||
badDir := filepath.Join(s.tempDir, "nonexistent")
|
||||
badFs := &FilesystemStorage{basePath: badDir}
|
||||
err = badFs.Health(ctx)
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
// Test Context cancellation
|
||||
func (s *FilesystemStorageTestSuite) TestContextCancellation() {
|
||||
// Create cancelled context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func() error
|
||||
}{
|
||||
{
|
||||
name: "Get with cancelled context",
|
||||
fn: func() error {
|
||||
_, err := s.fs.Get(ctx, "test.txt")
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Put with cancelled context",
|
||||
fn: func() error {
|
||||
return s.fs.Put(ctx, "test.txt", strings.NewReader("data"), nil)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Delete with cancelled context",
|
||||
fn: func() error {
|
||||
return s.fs.Delete(ctx, "test.txt")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Exists with cancelled context",
|
||||
fn: func() error {
|
||||
_, err := s.fs.Exists(ctx, "test.txt")
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "List with cancelled context",
|
||||
fn: func() error {
|
||||
_, err := s.fs.List(ctx, "test", nil)
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Stat with cancelled context",
|
||||
fn: func() error {
|
||||
_, err := s.fs.Stat(ctx, "test.txt")
|
||||
return err
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
err := tt.fn()
|
||||
s.Error(err)
|
||||
s.Equal(context.Canceled, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test concurrent access (race condition testing)
|
||||
func (s *FilesystemStorageTestSuite) TestConcurrentAccess() {
|
||||
ctx := context.Background()
|
||||
numGoroutines := 10
|
||||
numOperations := 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Concurrent writes
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < numOperations; j++ {
|
||||
key := fmt.Sprintf("concurrent/%d/%d.txt", id, j)
|
||||
data := fmt.Sprintf("data-%d-%d", id, j)
|
||||
err := s.fs.Put(ctx, key, strings.NewReader(data), nil)
|
||||
s.NoError(err)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all files exist
|
||||
objects, err := s.fs.List(ctx, "concurrent", nil)
|
||||
s.NoError(err)
|
||||
s.Equal(numGoroutines*numOperations, len(objects))
|
||||
}
|
||||
|
||||
// Test concurrent reads and writes
|
||||
func (s *FilesystemStorageTestSuite) TestConcurrentReadsAndWrites() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Setup: Create some initial files
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("shared/file-%d.txt", i)
|
||||
err := s.fs.Put(ctx, key, strings.NewReader(fmt.Sprintf("initial-%d", i)), nil)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numReaders := 5
|
||||
numWriters := 5
|
||||
numOps := 50
|
||||
|
||||
// Concurrent readers
|
||||
for i := 0; i < numReaders; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < numOps; j++ {
|
||||
key := fmt.Sprintf("shared/file-%d.txt", j%10)
|
||||
reader, err := s.fs.Get(ctx, key)
|
||||
if err == nil {
|
||||
io.ReadAll(reader)
|
||||
reader.Close()
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Concurrent writers
|
||||
for i := 0; i < numWriters; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < numOps; j++ {
|
||||
key := fmt.Sprintf("shared/writer-%d-%d.txt", id, j)
|
||||
data := fmt.Sprintf("writer-%d-%d", id, j)
|
||||
s.fs.Put(ctx, key, strings.NewReader(data), nil)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify quota tracking is consistent
|
||||
quotaInfo, err := s.fs.GetQuota(ctx)
|
||||
s.NoError(err)
|
||||
s.Greater(quotaInfo.Used, int64(0))
|
||||
}
|
||||
|
||||
// Test Delete updates quota correctly
|
||||
func (s *FilesystemStorageTestSuite) TestDeleteUpdatesQuota() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Put a file
|
||||
testData := "test data for quota tracking"
|
||||
err := s.fs.Put(ctx, "quota/test.txt", strings.NewReader(testData), nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Get quota before delete
|
||||
quotaBefore, err := s.fs.GetQuota(ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Delete the file
|
||||
err = s.fs.Delete(ctx, "quota/test.txt")
|
||||
s.NoError(err)
|
||||
|
||||
// Get quota after delete
|
||||
quotaAfter, err := s.fs.GetQuota(ctx)
|
||||
s.NoError(err)
|
||||
|
||||
// Quota should have decreased
|
||||
s.Less(quotaAfter.Used, quotaBefore.Used)
|
||||
}
|
||||
|
||||
// Test atomic write behavior
|
||||
func (s *FilesystemStorageTestSuite) TestAtomicWrite() {
|
||||
ctx := context.Background()
|
||||
key := "atomic/test.txt"
|
||||
|
||||
// Initial write
|
||||
err := s.fs.Put(ctx, key, strings.NewReader("initial"), nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Concurrent readers should never see partial writes
|
||||
var wg sync.WaitGroup
|
||||
stopReading := make(chan struct{})
|
||||
readErrors := make(chan error, 100)
|
||||
|
||||
// Start readers
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-stopReading:
|
||||
return
|
||||
default:
|
||||
reader, err := s.fs.Get(ctx, key)
|
||||
if err != nil {
|
||||
readErrors <- err
|
||||
continue
|
||||
}
|
||||
data, err := io.ReadAll(reader)
|
||||
reader.Close()
|
||||
if err != nil {
|
||||
readErrors <- err
|
||||
continue
|
||||
}
|
||||
// Data should be either "initial" or "updated", never partial
|
||||
content := string(data)
|
||||
if content != "initial" && content != "updated" {
|
||||
readErrors <- fmt.Errorf("read partial data: %s", content)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Perform update
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
err = s.fs.Put(ctx, key, strings.NewReader("updated"), nil)
|
||||
s.NoError(err)
|
||||
|
||||
// Stop readers
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
close(stopReading)
|
||||
wg.Wait()
|
||||
close(readErrors)
|
||||
|
||||
// Check for read errors
|
||||
for err := range readErrors {
|
||||
s.NoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test path sanitization
|
||||
func (s *FilesystemStorageTestSuite) TestPathSanitization() {
|
||||
ctx := context.Background()
|
||||
|
||||
maliciousPaths := []string{
|
||||
"../../../etc/passwd",
|
||||
"/../secret.txt",
|
||||
"./../../outside.txt",
|
||||
"//etc/passwd",
|
||||
}
|
||||
|
||||
for _, path := range maliciousPaths {
|
||||
s.Run(fmt.Sprintf("sanitize_%s", path), func() {
|
||||
err := s.fs.Put(ctx, path, strings.NewReader("malicious"), nil)
|
||||
s.NoError(err) // Should succeed but sanitize path
|
||||
|
||||
// Verify file is inside base directory
|
||||
sanitized := s.fs.keyToPath(path)
|
||||
s.True(strings.HasPrefix(sanitized, s.tempDir),
|
||||
"Sanitized path %s should be inside %s", sanitized, s.tempDir)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test checksum validation
|
||||
func (s *FilesystemStorageTestSuite) TestChecksumValidation() {
|
||||
ctx := context.Background()
|
||||
|
||||
testData := "checksum test data"
|
||||
// Correct checksums calculated for "checksum test data"
|
||||
correctMD5 := "7dd7323e8ce3e087972f93d3711ef62b"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts *storage.PutOptions
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid MD5",
|
||||
opts: &storage.PutOptions{ChecksumMD5: correctMD5},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid MD5",
|
||||
opts: &storage.PutOptions{ChecksumMD5: "invalid"},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "empty checksum (no validation)",
|
||||
opts: &storage.PutOptions{ChecksumMD5: ""},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
key := fmt.Sprintf("checksum/%s.txt", tt.name)
|
||||
err := s.fs.Put(ctx, key, strings.NewReader(testData), tt.opts)
|
||||
|
||||
if tt.expectError {
|
||||
s.Error(err)
|
||||
} else {
|
||||
s.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark Put operation
|
||||
func BenchmarkFilesystemPut(b *testing.B) {
|
||||
tempDir, _ := os.MkdirTemp("", "gohoarder-bench-*")
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
fs, _ := New(tempDir, 1024*1024*1024) // 1GB quota
|
||||
defer fs.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
data := strings.Repeat("x", 1024) // 1KB
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := fmt.Sprintf("bench/file-%d.txt", i)
|
||||
fs.Put(ctx, key, strings.NewReader(data), nil)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark Get operation
|
||||
func BenchmarkFilesystemGet(b *testing.B) {
|
||||
tempDir, _ := os.MkdirTemp("", "gohoarder-bench-*")
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
fs, _ := New(tempDir, 1024*1024*1024)
|
||||
defer fs.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
data := strings.Repeat("x", 1024)
|
||||
|
||||
// Setup: Create test file
|
||||
fs.Put(ctx, "bench/test.txt", strings.NewReader(data), nil)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
reader, _ := fs.Get(ctx, "bench/test.txt")
|
||||
if reader != nil {
|
||||
io.ReadAll(reader)
|
||||
reader.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
// StorageBackend defines the interface for package storage
|
||||
type StorageBackend interface {
|
||||
// Get retrieves a package by key
|
||||
Get(ctx context.Context, key string) (io.ReadCloser, error)
|
||||
|
||||
// Put stores a package
|
||||
Put(ctx context.Context, key string, data io.Reader, opts *PutOptions) error
|
||||
|
||||
// Delete removes a package
|
||||
Delete(ctx context.Context, key string) error
|
||||
|
||||
// Exists checks if a package exists
|
||||
Exists(ctx context.Context, key string) (bool, error)
|
||||
|
||||
// List lists packages with prefix
|
||||
List(ctx context.Context, prefix string, opts *ListOptions) ([]StorageObject, error)
|
||||
|
||||
// Stat gets package metadata
|
||||
Stat(ctx context.Context, key string) (*StorageInfo, error)
|
||||
|
||||
// GetQuota returns quota information
|
||||
GetQuota(ctx context.Context) (*QuotaInfo, error)
|
||||
|
||||
// Health checks backend health
|
||||
Health(ctx context.Context) error
|
||||
|
||||
// Close closes the backend
|
||||
Close() error
|
||||
}
|
||||
|
||||
// PutOptions contains options for Put operations
|
||||
type PutOptions struct {
|
||||
ContentType string
|
||||
Metadata map[string]string
|
||||
ChecksumMD5 string
|
||||
ChecksumSHA256 string
|
||||
}
|
||||
|
||||
// ListOptions contains options for List operations
|
||||
type ListOptions struct {
|
||||
MaxResults int
|
||||
Offset int
|
||||
}
|
||||
|
||||
// StorageObject represents a stored object
|
||||
type StorageObject struct {
|
||||
Key string
|
||||
Size int64
|
||||
Modified time.Time
|
||||
ETag string
|
||||
}
|
||||
|
||||
// StorageInfo contains detailed object information
|
||||
type StorageInfo struct {
|
||||
Key string
|
||||
Size int64
|
||||
Modified time.Time
|
||||
ETag string
|
||||
ContentType string
|
||||
Metadata map[string]string
|
||||
Checksums *Checksums
|
||||
}
|
||||
|
||||
// Checksums contains file checksums
|
||||
type Checksums struct {
|
||||
MD5 string
|
||||
SHA256 string
|
||||
}
|
||||
|
||||
// QuotaInfo contains quota information
|
||||
type QuotaInfo struct {
|
||||
Used int64
|
||||
Available int64
|
||||
Limit int64
|
||||
}
|
||||
|
||||
// LocalPathProvider is an optional interface that storage backends can implement
|
||||
// to provide direct file system paths for scanning without creating temp copies
|
||||
type LocalPathProvider interface {
|
||||
// GetLocalPath returns the local filesystem path for a storage key
|
||||
// Returns empty string if the backend doesn't support local paths (e.g., S3, SMB)
|
||||
GetLocalPath(ctx context.Context, key string) (string, error)
|
||||
}
|
||||
@@ -0,0 +1,443 @@
|
||||
package s3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
stderrors "errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3/types"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/errors"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metrics"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/storage"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// S3Storage implements storage.StorageBackend for AWS S3
|
||||
type S3Storage struct {
|
||||
client *s3.Client
|
||||
bucket string
|
||||
prefix string
|
||||
quota int64
|
||||
mu sync.RWMutex
|
||||
used int64
|
||||
}
|
||||
|
||||
// Config holds S3 configuration
|
||||
type Config struct {
|
||||
Bucket string
|
||||
Region string
|
||||
Endpoint string // For S3-compatible services (MinIO, etc.)
|
||||
AccessKeyID string
|
||||
SecretAccessKey string
|
||||
Prefix string // Optional prefix for all keys
|
||||
Quota int64 // Quota in bytes (0 = unlimited)
|
||||
ForcePathStyle bool // For S3-compatible services
|
||||
}
|
||||
|
||||
// New creates a new S3 storage backend
|
||||
func New(ctx context.Context, cfg Config) (*S3Storage, error) {
|
||||
if cfg.Bucket == "" {
|
||||
return nil, errors.New(errors.ErrCodeInvalidConfig, "S3 bucket is required")
|
||||
}
|
||||
|
||||
if cfg.Region == "" {
|
||||
return nil, errors.New(errors.ErrCodeInvalidConfig, "S3 region is required")
|
||||
}
|
||||
|
||||
// Build AWS config
|
||||
var awsCfg aws.Config
|
||||
var err error
|
||||
|
||||
if cfg.AccessKeyID != "" && cfg.SecretAccessKey != "" {
|
||||
// Use static credentials
|
||||
awsCfg, err = config.LoadDefaultConfig(ctx,
|
||||
config.WithRegion(cfg.Region),
|
||||
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
|
||||
cfg.AccessKeyID,
|
||||
cfg.SecretAccessKey,
|
||||
"",
|
||||
)),
|
||||
)
|
||||
} else {
|
||||
// Use default credential chain
|
||||
awsCfg, err = config.LoadDefaultConfig(ctx,
|
||||
config.WithRegion(cfg.Region),
|
||||
)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to load AWS config")
|
||||
}
|
||||
|
||||
// Create S3 client
|
||||
var s3Options []func(*s3.Options)
|
||||
|
||||
if cfg.Endpoint != "" {
|
||||
s3Options = append(s3Options, func(o *s3.Options) {
|
||||
o.BaseEndpoint = aws.String(cfg.Endpoint)
|
||||
o.UsePathStyle = cfg.ForcePathStyle
|
||||
})
|
||||
}
|
||||
|
||||
client := s3.NewFromConfig(awsCfg, s3Options...)
|
||||
|
||||
s3Storage := &S3Storage{
|
||||
client: client,
|
||||
bucket: cfg.Bucket,
|
||||
prefix: strings.TrimSuffix(cfg.Prefix, "/"),
|
||||
quota: cfg.Quota,
|
||||
}
|
||||
|
||||
// Calculate initial usage
|
||||
if err := s3Storage.calculateUsage(ctx); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to calculate initial S3 storage usage")
|
||||
}
|
||||
|
||||
return s3Storage, nil
|
||||
}
|
||||
|
||||
// Get retrieves a file from S3
|
||||
func (s *S3Storage) Get(ctx context.Context, key string) (io.ReadCloser, error) {
|
||||
s3Key := s.buildKey(key)
|
||||
|
||||
input := &s3.GetObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(s3Key),
|
||||
}
|
||||
|
||||
result, err := s.client.GetObject(ctx, input)
|
||||
if err != nil {
|
||||
if isNotFoundError(err) {
|
||||
metrics.RecordStorageOperation("s3", "get", "not_found")
|
||||
return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key))
|
||||
}
|
||||
metrics.RecordStorageOperation("s3", "get", "error")
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get object from S3")
|
||||
}
|
||||
|
||||
metrics.RecordStorageOperation("s3", "get", "success")
|
||||
return result.Body, nil
|
||||
}
|
||||
|
||||
// Put stores a file in S3
|
||||
func (s *S3Storage) Put(ctx context.Context, key string, data io.Reader, opts *storage.PutOptions) error {
|
||||
s3Key := s.buildKey(key)
|
||||
|
||||
// Read data into buffer to calculate checksums and size
|
||||
var buf bytes.Buffer
|
||||
md5Hash := md5.New()
|
||||
sha256Hash := sha256.New()
|
||||
multiWriter := io.MultiWriter(&buf, md5Hash, sha256Hash)
|
||||
|
||||
written, err := io.Copy(multiWriter, data)
|
||||
if err != nil {
|
||||
metrics.RecordStorageOperation("s3", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to read data")
|
||||
}
|
||||
|
||||
// Check quota before upload
|
||||
if s.quota > 0 {
|
||||
s.mu.RLock()
|
||||
used := s.used
|
||||
s.mu.RUnlock()
|
||||
|
||||
if used+written > s.quota {
|
||||
metrics.RecordStorageOperation("s3", "put", "quota_exceeded")
|
||||
return errors.QuotaExceeded(s.quota)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify checksums if provided
|
||||
if opts != nil {
|
||||
md5Sum := hex.EncodeToString(md5Hash.Sum(nil))
|
||||
sha256Sum := hex.EncodeToString(sha256Hash.Sum(nil))
|
||||
|
||||
if opts.ChecksumMD5 != "" && opts.ChecksumMD5 != md5Sum {
|
||||
metrics.RecordStorageOperation("s3", "put", "checksum_error")
|
||||
return errors.New(errors.ErrCodeChecksumMismatch, "MD5 checksum mismatch")
|
||||
}
|
||||
|
||||
if opts.ChecksumSHA256 != "" && opts.ChecksumSHA256 != sha256Sum {
|
||||
metrics.RecordStorageOperation("s3", "put", "checksum_error")
|
||||
return errors.New(errors.ErrCodeChecksumMismatch, "SHA256 checksum mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare metadata
|
||||
metadata := make(map[string]string)
|
||||
if opts != nil && opts.Metadata != nil {
|
||||
metadata = opts.Metadata
|
||||
}
|
||||
|
||||
// Build put input
|
||||
input := &s3.PutObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(s3Key),
|
||||
Body: bytes.NewReader(buf.Bytes()),
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
if opts != nil && opts.ContentType != "" {
|
||||
input.ContentType = aws.String(opts.ContentType)
|
||||
}
|
||||
|
||||
// Upload to S3
|
||||
_, err = s.client.PutObject(ctx, input)
|
||||
if err != nil {
|
||||
metrics.RecordStorageOperation("s3", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to upload to S3")
|
||||
}
|
||||
|
||||
// Update usage
|
||||
s.mu.Lock()
|
||||
s.used += written
|
||||
currentUsed := s.used
|
||||
s.mu.Unlock()
|
||||
|
||||
metrics.RecordStorageOperation("s3", "put", "success")
|
||||
metrics.UpdateCacheSize("s3", currentUsed)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a file from S3
|
||||
func (s *S3Storage) Delete(ctx context.Context, key string) error {
|
||||
s3Key := s.buildKey(key)
|
||||
|
||||
// Get size before deletion for quota tracking
|
||||
statInfo, err := s.Stat(ctx, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
input := &s3.DeleteObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(s3Key),
|
||||
}
|
||||
|
||||
_, err = s.client.DeleteObject(ctx, input)
|
||||
if err != nil {
|
||||
metrics.RecordStorageOperation("s3", "delete", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to delete from S3")
|
||||
}
|
||||
|
||||
// Update usage
|
||||
s.mu.Lock()
|
||||
s.used -= statInfo.Size
|
||||
currentUsed := s.used
|
||||
s.mu.Unlock()
|
||||
|
||||
metrics.RecordStorageOperation("s3", "delete", "success")
|
||||
metrics.UpdateCacheSize("s3", currentUsed)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists checks if a file exists in S3
|
||||
func (s *S3Storage) Exists(ctx context.Context, key string) (bool, error) {
|
||||
s3Key := s.buildKey(key)
|
||||
|
||||
input := &s3.HeadObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(s3Key),
|
||||
}
|
||||
|
||||
_, err := s.client.HeadObject(ctx, input)
|
||||
if err != nil {
|
||||
if isNotFoundError(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to check existence in S3")
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// List lists files with prefix in S3
|
||||
func (s *S3Storage) List(ctx context.Context, prefix string, opts *storage.ListOptions) ([]storage.StorageObject, error) {
|
||||
s3Prefix := s.buildKey(prefix)
|
||||
|
||||
input := &s3.ListObjectsV2Input{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Prefix: aws.String(s3Prefix),
|
||||
}
|
||||
|
||||
var objects []storage.StorageObject
|
||||
paginator := s3.NewListObjectsV2Paginator(s.client, input)
|
||||
|
||||
for paginator.HasMorePages() {
|
||||
page, err := paginator.NextPage(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to list objects in S3")
|
||||
}
|
||||
|
||||
for _, obj := range page.Contents {
|
||||
key := s.stripPrefix(*obj.Key)
|
||||
objects = append(objects, storage.StorageObject{
|
||||
Key: key,
|
||||
Size: *obj.Size,
|
||||
Modified: *obj.LastModified,
|
||||
ETag: strings.Trim(*obj.ETag, "\""),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Apply pagination if requested
|
||||
if opts != nil {
|
||||
start := opts.Offset
|
||||
end := len(objects)
|
||||
if opts.MaxResults > 0 && start+opts.MaxResults < end {
|
||||
end = start + opts.MaxResults
|
||||
}
|
||||
if start < len(objects) {
|
||||
objects = objects[start:end]
|
||||
} else {
|
||||
objects = []storage.StorageObject{}
|
||||
}
|
||||
}
|
||||
|
||||
return objects, nil
|
||||
}
|
||||
|
||||
// Stat gets file metadata from S3
|
||||
func (s *S3Storage) Stat(ctx context.Context, key string) (*storage.StorageInfo, error) {
|
||||
s3Key := s.buildKey(key)
|
||||
|
||||
input := &s3.HeadObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(s3Key),
|
||||
}
|
||||
|
||||
result, err := s.client.HeadObject(ctx, input)
|
||||
if err != nil {
|
||||
if isNotFoundError(err) {
|
||||
return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key))
|
||||
}
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to stat object in S3")
|
||||
}
|
||||
|
||||
info := &storage.StorageInfo{
|
||||
Key: key,
|
||||
Size: *result.ContentLength,
|
||||
Modified: *result.LastModified,
|
||||
ETag: strings.Trim(*result.ETag, "\""),
|
||||
Metadata: result.Metadata,
|
||||
}
|
||||
|
||||
if result.ContentType != nil {
|
||||
info.ContentType = *result.ContentType
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// GetQuota returns quota information
|
||||
func (s *S3Storage) GetQuota(ctx context.Context) (*storage.QuotaInfo, error) {
|
||||
s.mu.RLock()
|
||||
used := s.used
|
||||
s.mu.RUnlock()
|
||||
|
||||
available := s.quota - used
|
||||
if available < 0 {
|
||||
available = 0
|
||||
}
|
||||
|
||||
return &storage.QuotaInfo{
|
||||
Used: used,
|
||||
Available: available,
|
||||
Limit: s.quota,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Health checks S3 health
|
||||
func (s *S3Storage) Health(ctx context.Context) error {
|
||||
// Try to list bucket to verify connectivity
|
||||
input := &s3.ListObjectsV2Input{
|
||||
Bucket: aws.String(s.bucket),
|
||||
MaxKeys: aws.Int32(1),
|
||||
}
|
||||
|
||||
_, err := s.client.ListObjectsV2(ctx, input)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "S3 health check failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the storage backend
|
||||
func (s *S3Storage) Close() error {
|
||||
// No cleanup needed for S3 client
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildKey builds the full S3 key with prefix
|
||||
func (s *S3Storage) buildKey(key string) string {
|
||||
key = strings.TrimPrefix(key, "/")
|
||||
if s.prefix != "" {
|
||||
return s.prefix + "/" + key
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
// stripPrefix removes the configured prefix from an S3 key
|
||||
func (s *S3Storage) stripPrefix(s3Key string) string {
|
||||
if s.prefix != "" {
|
||||
return strings.TrimPrefix(s3Key, s.prefix+"/")
|
||||
}
|
||||
return s3Key
|
||||
}
|
||||
|
||||
// calculateUsage calculates current S3 storage usage
|
||||
func (s *S3Storage) calculateUsage(ctx context.Context) error {
|
||||
var total int64
|
||||
|
||||
input := &s3.ListObjectsV2Input{
|
||||
Bucket: aws.String(s.bucket),
|
||||
}
|
||||
|
||||
if s.prefix != "" {
|
||||
input.Prefix = aws.String(s.prefix + "/")
|
||||
}
|
||||
|
||||
paginator := s3.NewListObjectsV2Paginator(s.client, input)
|
||||
|
||||
for paginator.HasMorePages() {
|
||||
page, err := paginator.NextPage(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, obj := range page.Contents {
|
||||
total += *obj.Size
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.used = total
|
||||
s.mu.Unlock()
|
||||
|
||||
metrics.UpdateCacheSize("s3", total)
|
||||
return nil
|
||||
}
|
||||
|
||||
// isNotFoundError checks if an error is a "not found" error
|
||||
func isNotFoundError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var notFound *types.NotFound
|
||||
var noSuchKey *types.NoSuchKey
|
||||
|
||||
return stderrors.As(err, ¬Found) || stderrors.As(err, &noSuchKey)
|
||||
}
|
||||
@@ -0,0 +1,579 @@
|
||||
package smb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hirochachacha/go-smb2"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/errors"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/metrics"
|
||||
"github.com/lukaszraczylo/gohoarder/pkg/storage"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// SMBStorage implements storage.StorageBackend for SMB/CIFS shares
|
||||
type SMBStorage struct {
|
||||
host string
|
||||
share string
|
||||
basePath string
|
||||
username string
|
||||
password string
|
||||
quota int64
|
||||
mu sync.RWMutex
|
||||
used int64
|
||||
connPool chan *smbConnection
|
||||
poolSize int
|
||||
}
|
||||
|
||||
// smbConnection wraps an SMB session and share
|
||||
type smbConnection struct {
|
||||
conn net.Conn
|
||||
session *smb2.Session
|
||||
share *smb2.Share
|
||||
lastUse time.Time
|
||||
}
|
||||
|
||||
// Config holds SMB configuration
|
||||
type Config struct {
|
||||
Host string // SMB server hostname or IP
|
||||
Port int // SMB server port (default: 445)
|
||||
Share string // SMB share name
|
||||
BasePath string // Base path within the share
|
||||
Username string // SMB username
|
||||
Password string // SMB password
|
||||
Domain string // SMB domain (optional)
|
||||
Quota int64 // Quota in bytes (0 = unlimited)
|
||||
PoolSize int // Connection pool size (default: 5)
|
||||
}
|
||||
|
||||
// New creates a new SMB storage backend
|
||||
func New(ctx context.Context, cfg Config) (*SMBStorage, error) {
|
||||
if cfg.Host == "" {
|
||||
return nil, errors.New(errors.ErrCodeInvalidConfig, "SMB host is required")
|
||||
}
|
||||
|
||||
if cfg.Share == "" {
|
||||
return nil, errors.New(errors.ErrCodeInvalidConfig, "SMB share is required")
|
||||
}
|
||||
|
||||
if cfg.Port == 0 {
|
||||
cfg.Port = 445 // Default SMB port
|
||||
}
|
||||
|
||||
if cfg.PoolSize == 0 {
|
||||
cfg.PoolSize = 5 // Default pool size
|
||||
}
|
||||
|
||||
smbStorage := &SMBStorage{
|
||||
host: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
||||
share: cfg.Share,
|
||||
basePath: strings.Trim(cfg.BasePath, "/\\"),
|
||||
username: cfg.Username,
|
||||
password: cfg.Password,
|
||||
quota: cfg.Quota,
|
||||
connPool: make(chan *smbConnection, cfg.PoolSize),
|
||||
poolSize: cfg.PoolSize,
|
||||
}
|
||||
|
||||
// Initialize connection pool
|
||||
for i := 0; i < cfg.PoolSize; i++ {
|
||||
conn, err := smbStorage.createConnection(ctx)
|
||||
if err != nil {
|
||||
// Clean up any created connections
|
||||
close(smbStorage.connPool)
|
||||
for c := range smbStorage.connPool {
|
||||
c.close()
|
||||
}
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create SMB connection pool")
|
||||
}
|
||||
smbStorage.connPool <- conn
|
||||
}
|
||||
|
||||
// Calculate initial usage
|
||||
if err := smbStorage.calculateUsage(ctx); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to calculate initial SMB storage usage")
|
||||
}
|
||||
|
||||
return smbStorage, nil
|
||||
}
|
||||
|
||||
// createConnection creates a new SMB connection
|
||||
func (s *SMBStorage) createConnection(ctx context.Context) (*smbConnection, error) {
|
||||
conn, err := net.Dial("tcp", s.host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dialer := &smb2.Dialer{
|
||||
Initiator: &smb2.NTLMInitiator{
|
||||
User: s.username,
|
||||
Password: s.password,
|
||||
},
|
||||
}
|
||||
|
||||
session, err := dialer.Dial(conn)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
share, err := session.Mount(s.share)
|
||||
if err != nil {
|
||||
session.Logoff()
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &smbConnection{
|
||||
conn: conn,
|
||||
session: session,
|
||||
share: share,
|
||||
lastUse: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getConnection gets a connection from the pool
|
||||
func (s *SMBStorage) getConnection(ctx context.Context) (*smbConnection, error) {
|
||||
select {
|
||||
case conn := <-s.connPool:
|
||||
conn.lastUse = time.Now()
|
||||
return conn, nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(30 * time.Second):
|
||||
return nil, errors.New(errors.ErrCodeStorageFailure, "timeout waiting for SMB connection")
|
||||
}
|
||||
}
|
||||
|
||||
// returnConnection returns a connection to the pool
|
||||
func (s *SMBStorage) returnConnection(conn *smbConnection) {
|
||||
select {
|
||||
case s.connPool <- conn:
|
||||
default:
|
||||
// Pool is full, close the connection
|
||||
conn.close()
|
||||
}
|
||||
}
|
||||
|
||||
// close closes an SMB connection
|
||||
func (c *smbConnection) close() {
|
||||
if c.share != nil {
|
||||
c.share.Umount()
|
||||
}
|
||||
if c.session != nil {
|
||||
c.session.Logoff()
|
||||
}
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a file from SMB share
|
||||
func (s *SMBStorage) Get(ctx context.Context, key string) (io.ReadCloser, error) {
|
||||
conn, err := s.getConnection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
path := s.keyToPath(key)
|
||||
|
||||
// Open file
|
||||
file, err := conn.share.Open(path)
|
||||
if err != nil {
|
||||
s.returnConnection(conn)
|
||||
if os.IsNotExist(err) {
|
||||
metrics.RecordStorageOperation("smb", "get", "not_found")
|
||||
return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key))
|
||||
}
|
||||
metrics.RecordStorageOperation("smb", "get", "error")
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to open SMB file")
|
||||
}
|
||||
|
||||
// Read entire file into memory and close SMB connection
|
||||
// This is necessary because we need to return the connection to the pool
|
||||
data, err := io.ReadAll(file)
|
||||
file.Close()
|
||||
s.returnConnection(conn)
|
||||
|
||||
if err != nil {
|
||||
metrics.RecordStorageOperation("smb", "get", "error")
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to read SMB file")
|
||||
}
|
||||
|
||||
metrics.RecordStorageOperation("smb", "get", "success")
|
||||
return io.NopCloser(bytes.NewReader(data)), nil
|
||||
}
|
||||
|
||||
// Put stores a file on SMB share
|
||||
func (s *SMBStorage) Put(ctx context.Context, key string, data io.Reader, opts *storage.PutOptions) error {
|
||||
conn, err := s.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.returnConnection(conn)
|
||||
|
||||
path := s.keyToPath(key)
|
||||
dir := filepath.Dir(path)
|
||||
|
||||
// Create directory structure
|
||||
if err := conn.share.MkdirAll(dir, 0755); err != nil {
|
||||
metrics.RecordStorageOperation("smb", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create SMB directory")
|
||||
}
|
||||
|
||||
// Read data into buffer to calculate checksums and size
|
||||
var buf bytes.Buffer
|
||||
md5Hash := md5.New()
|
||||
sha256Hash := sha256.New()
|
||||
multiWriter := io.MultiWriter(&buf, md5Hash, sha256Hash)
|
||||
|
||||
written, err := io.Copy(multiWriter, data)
|
||||
if err != nil {
|
||||
metrics.RecordStorageOperation("smb", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to read data")
|
||||
}
|
||||
|
||||
// Check quota
|
||||
if s.quota > 0 {
|
||||
s.mu.RLock()
|
||||
used := s.used
|
||||
s.mu.RUnlock()
|
||||
|
||||
if used+written > s.quota {
|
||||
metrics.RecordStorageOperation("smb", "put", "quota_exceeded")
|
||||
return errors.QuotaExceeded(s.quota)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify checksums if provided
|
||||
if opts != nil {
|
||||
md5Sum := hex.EncodeToString(md5Hash.Sum(nil))
|
||||
sha256Sum := hex.EncodeToString(sha256Hash.Sum(nil))
|
||||
|
||||
if opts.ChecksumMD5 != "" && opts.ChecksumMD5 != md5Sum {
|
||||
metrics.RecordStorageOperation("smb", "put", "checksum_error")
|
||||
return errors.New(errors.ErrCodeChecksumMismatch, "MD5 checksum mismatch")
|
||||
}
|
||||
|
||||
if opts.ChecksumSHA256 != "" && opts.ChecksumSHA256 != sha256Sum {
|
||||
metrics.RecordStorageOperation("smb", "put", "checksum_error")
|
||||
return errors.New(errors.ErrCodeChecksumMismatch, "SHA256 checksum mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
// Create temp file for atomic write
|
||||
tempPath := path + ".tmp"
|
||||
file, err := conn.share.Create(tempPath)
|
||||
if err != nil {
|
||||
metrics.RecordStorageOperation("smb", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create SMB temp file")
|
||||
}
|
||||
|
||||
// Write data
|
||||
_, err = io.Copy(file, bytes.NewReader(buf.Bytes()))
|
||||
file.Close()
|
||||
|
||||
if err != nil {
|
||||
conn.share.Remove(tempPath)
|
||||
metrics.RecordStorageOperation("smb", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to write SMB file")
|
||||
}
|
||||
|
||||
// Atomic rename
|
||||
if err := conn.share.Rename(tempPath, path); err != nil {
|
||||
conn.share.Remove(tempPath)
|
||||
metrics.RecordStorageOperation("smb", "put", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to rename SMB temp file")
|
||||
}
|
||||
|
||||
// Update usage
|
||||
s.mu.Lock()
|
||||
s.used += written
|
||||
currentUsed := s.used
|
||||
s.mu.Unlock()
|
||||
|
||||
metrics.RecordStorageOperation("smb", "put", "success")
|
||||
metrics.UpdateCacheSize("smb", currentUsed)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a file from SMB share
|
||||
func (s *SMBStorage) Delete(ctx context.Context, key string) error {
|
||||
conn, err := s.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.returnConnection(conn)
|
||||
|
||||
path := s.keyToPath(key)
|
||||
|
||||
// Get size before deletion
|
||||
info, err := conn.share.Stat(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
metrics.RecordStorageOperation("smb", "delete", "not_found")
|
||||
return errors.NotFound(fmt.Sprintf("file not found: %s", key))
|
||||
}
|
||||
metrics.RecordStorageOperation("smb", "delete", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to stat SMB file")
|
||||
}
|
||||
|
||||
size := info.Size()
|
||||
|
||||
if err := conn.share.Remove(path); err != nil {
|
||||
metrics.RecordStorageOperation("smb", "delete", "error")
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to delete SMB file")
|
||||
}
|
||||
|
||||
// Update usage
|
||||
s.mu.Lock()
|
||||
s.used -= size
|
||||
currentUsed := s.used
|
||||
s.mu.Unlock()
|
||||
|
||||
metrics.RecordStorageOperation("smb", "delete", "success")
|
||||
metrics.UpdateCacheSize("smb", currentUsed)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists checks if a file exists on SMB share
|
||||
func (s *SMBStorage) Exists(ctx context.Context, key string) (bool, error) {
|
||||
conn, err := s.getConnection(ctx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer s.returnConnection(conn)
|
||||
|
||||
path := s.keyToPath(key)
|
||||
|
||||
_, err = conn.share.Stat(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to check SMB file existence")
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// List lists files with prefix on SMB share
|
||||
func (s *SMBStorage) List(ctx context.Context, prefix string, opts *storage.ListOptions) ([]storage.StorageObject, error) {
|
||||
conn, err := s.getConnection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer s.returnConnection(conn)
|
||||
|
||||
searchPath := s.keyToPath(prefix)
|
||||
var objects []storage.StorageObject
|
||||
|
||||
err = s.walkPath(conn.share, searchPath, func(path string, info os.FileInfo) error {
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := s.pathToKey(path)
|
||||
objects = append(objects, storage.StorageObject{
|
||||
Key: key,
|
||||
Size: info.Size(),
|
||||
Modified: info.ModTime(),
|
||||
})
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to list SMB files")
|
||||
}
|
||||
|
||||
// Apply pagination if requested
|
||||
if opts != nil {
|
||||
start := opts.Offset
|
||||
end := len(objects)
|
||||
if opts.MaxResults > 0 && start+opts.MaxResults < end {
|
||||
end = start + opts.MaxResults
|
||||
}
|
||||
if start < len(objects) {
|
||||
objects = objects[start:end]
|
||||
} else {
|
||||
objects = []storage.StorageObject{}
|
||||
}
|
||||
}
|
||||
|
||||
return objects, nil
|
||||
}
|
||||
|
||||
// Stat gets file metadata from SMB share
|
||||
func (s *SMBStorage) Stat(ctx context.Context, key string) (*storage.StorageInfo, error) {
|
||||
conn, err := s.getConnection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer s.returnConnection(conn)
|
||||
|
||||
path := s.keyToPath(key)
|
||||
|
||||
info, err := conn.share.Stat(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key))
|
||||
}
|
||||
return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to stat SMB file")
|
||||
}
|
||||
|
||||
return &storage.StorageInfo{
|
||||
Key: key,
|
||||
Size: info.Size(),
|
||||
Modified: info.ModTime(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetQuota returns quota information
|
||||
func (s *SMBStorage) GetQuota(ctx context.Context) (*storage.QuotaInfo, error) {
|
||||
s.mu.RLock()
|
||||
used := s.used
|
||||
s.mu.RUnlock()
|
||||
|
||||
available := s.quota - used
|
||||
if available < 0 {
|
||||
available = 0
|
||||
}
|
||||
|
||||
return &storage.QuotaInfo{
|
||||
Used: used,
|
||||
Available: available,
|
||||
Limit: s.quota,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Health checks SMB health
|
||||
func (s *SMBStorage) Health(ctx context.Context) error {
|
||||
conn, err := s.getConnection(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "SMB health check failed - connection error")
|
||||
}
|
||||
defer s.returnConnection(conn)
|
||||
|
||||
// Try to stat the base path
|
||||
path := s.keyToPath("")
|
||||
_, err = conn.share.Stat(path)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return errors.Wrap(err, errors.ErrCodeStorageFailure, "SMB health check failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the storage backend
|
||||
func (s *SMBStorage) Close() error {
|
||||
close(s.connPool)
|
||||
for conn := range s.connPool {
|
||||
conn.close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// keyToPath converts a storage key to SMB path
|
||||
func (s *SMBStorage) keyToPath(key string) string {
|
||||
key = strings.TrimPrefix(key, "/")
|
||||
key = filepath.Clean(key)
|
||||
|
||||
// Remove path traversal attempts
|
||||
for strings.HasPrefix(key, "../") || strings.HasPrefix(key, "..\\") {
|
||||
key = strings.TrimPrefix(key, "../")
|
||||
key = strings.TrimPrefix(key, "..\\")
|
||||
}
|
||||
|
||||
key = filepath.Clean(key)
|
||||
if key == ".." || strings.HasPrefix(key, "../") || strings.HasPrefix(key, "..\\") {
|
||||
key = ""
|
||||
}
|
||||
|
||||
if s.basePath != "" {
|
||||
return filepath.Join(s.basePath, key)
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
// pathToKey converts an SMB path back to a storage key
|
||||
func (s *SMBStorage) pathToKey(path string) string {
|
||||
if s.basePath != "" {
|
||||
path = strings.TrimPrefix(path, s.basePath)
|
||||
path = strings.TrimPrefix(path, "/")
|
||||
path = strings.TrimPrefix(path, "\\")
|
||||
}
|
||||
return filepath.ToSlash(path)
|
||||
}
|
||||
|
||||
// walkPath recursively walks an SMB directory
|
||||
func (s *SMBStorage) walkPath(share *smb2.Share, path string, fn func(string, os.FileInfo) error) error {
|
||||
info, err := share.Stat(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !info.IsDir() {
|
||||
return fn(path, info)
|
||||
}
|
||||
|
||||
entries, err := share.ReadDir(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
entryPath := filepath.Join(path, entry.Name())
|
||||
if entry.IsDir() {
|
||||
if err := s.walkPath(share, entryPath, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := fn(entryPath, entry); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateUsage calculates current SMB storage usage
|
||||
func (s *SMBStorage) calculateUsage(ctx context.Context) error {
|
||||
conn, err := s.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.returnConnection(conn)
|
||||
|
||||
var total int64
|
||||
basePath := s.keyToPath("")
|
||||
|
||||
err = s.walkPath(conn.share, basePath, func(path string, info os.FileInfo) error {
|
||||
if !info.IsDir() {
|
||||
total += info.Size()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.used = total
|
||||
s.mu.Unlock()
|
||||
|
||||
metrics.UpdateCacheSize("smb", total)
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user