Files
gohoarder/pkg/storage/smb/smb.go
T
lukaszraczylo 6b037a92b4 refactor: reorganize struct fields, add new handlers and storage backends
- [x] Reorder struct fields across codebase for consistency
- [x] Add analytics event handlers and tests
- [x] Add authentication API key management handlers and tests
- [x] Add pre-warming control handlers and tests
- [x] Implement S3 storage backend with tests
- [x] Implement SMB/CIFS storage backend with tests
- [x] Add CDN middleware tests
- [x] Integrate analytics tracking into cache manager
- [x] Add S3 and SMB storage initialization in app setup
- [x] Add CDN caching to proxy handlers
- [x] Remove distributed locking (Redis lock manager)
- [x] Remove proxy common package and utilities
- [x] Remove standalone HTTP server package
- [x] Remove logger middleware
- [x] Simplify error handling utilities
- [x] Update config with S3 and SMB options
- [x] Update cache manager signature to include analytics
2026-01-03 00:18:58 +00:00

543 lines
13 KiB
Go

package smb
import (
"context"
"fmt"
"io"
"net"
"os"
"path/filepath"
"strings"
"time"
"github.com/hirochachacha/go-smb2"
"github.com/lukaszraczylo/gohoarder/pkg/errors"
"github.com/lukaszraczylo/gohoarder/pkg/storage"
"github.com/rs/zerolog/log"
)
// Config holds SMB storage configuration.
type Config struct {
	// Host, Share and Path locate the storage root: the SMB server address,
	// the share to mount, and an optional directory inside that share
	// (empty means the share root).
	Host, Share, Path string

	// Username, Password and Domain are the NTLM credentials used when the
	// SMB session is established.
	Username, Password, Domain string

	// Port is the server's TCP port; New uses 445 when it is zero.
	Port int

	// MaxSizeBytes limits the total bytes stored under Path; zero disables
	// the quota check.
	MaxSizeBytes int64

	// PoolSize is the number of connections kept in the pool; New uses 5
	// when it is zero.
	PoolSize int
}
// SMBStorage implements storage.StorageBackend on top of an SMB/CIFS share.
// Operations borrow connections from connPool, dialling a new connection
// when the pool is empty.
type SMBStorage struct {
// connPool buffers idle connections; its capacity is the configured pool size.
connPool chan *smbConnection
// config is the configuration passed to New, after New applied its defaults
// (Port, PoolSize) and trimmed Path.
config Config
// maxSizeBytes is the quota enforced by Put; values <= 0 disable the check.
maxSizeBytes int64
// poolSize records the pool capacity (Config.PoolSize after defaulting).
poolSize int
}
// smbConnection is one pooled SMB connection: the TCP connection, the
// NTLM-authenticated session running over it, and the mounted share.
type smbConnection struct {
// conn is the underlying TCP connection.
conn net.Conn
// session is the SMB session established over conn.
session *smb2.Session
// share is the mounted share that all file operations go through.
share *smb2.Share
// lastUse is the timestamp getConnection compares against its idle limit.
lastUse time.Time
}
// New creates an SMB storage backend from cfg.
//
// Host and Share are required. A Port or PoolSize that is zero or negative
// falls back to the default (445 and 5 respectively). Previously, a negative
// PoolSize reached make(chan, n) and panicked. Path is stripped of leading
// and trailing separators, and any remaining '/' is converted to '\'. This
// makes every path this backend sends to the share use a single separator
// style.
//
// Up to PoolSize connections are opened eagerly. Failures are only logged:
// New still succeeds when the server is unreachable, and getConnection
// opens connections on demand later.
func New(cfg Config) (*SMBStorage, error) {
	if cfg.Host == "" {
		return nil, fmt.Errorf("SMB host is required")
	}
	if cfg.Share == "" {
		return nil, fmt.Errorf("SMB share is required")
	}
	if cfg.Port <= 0 {
		cfg.Port = 445 // Default SMB port
	}
	if cfg.PoolSize <= 0 {
		cfg.PoolSize = 5 // Default pool size
	}

	// Normalize path to the '\'-separated form used for every SMB path.
	cfg.Path = strings.ReplaceAll(strings.Trim(cfg.Path, "/\\"), "/", "\\")

	// Named s (not storage) so the imported storage package is not shadowed.
	s := &SMBStorage{
		config:       cfg,
		maxSizeBytes: cfg.MaxSizeBytes,
		poolSize:     cfg.PoolSize,
		connPool:     make(chan *smbConnection, cfg.PoolSize),
	}

	// Pre-populate connection pool
	for i := 0; i < cfg.PoolSize; i++ {
		conn, err := s.createConnection()
		if err != nil {
			log.Warn().Err(err).Int("attempt", i).Msg("Failed to create initial SMB connection")
			continue
		}
		s.connPool <- conn
	}

	log.Info().
		Str("host", cfg.Host).
		Int("port", cfg.Port).
		Str("share", cfg.Share).
		Str("path", cfg.Path).
		Int("pool_size", cfg.PoolSize).
		Msg("SMB storage initialized")
	return s, nil
}
// createConnection dials Host:Port, authenticates with NTLM using the
// configured credentials, and mounts Share. If any step fails, everything
// built so far is torn down.
//
// The TCP dial, the SMB session setup (d.Dial) and the share mount are all
// bounded by handshakeTimeout. Previously only the dial was bounded here.
// A server that accepted TCP but stalled during SMB setup could therefore
// hold the caller for as long as go-smb2 waits (presumably indefinitely).
// The deadline is cleared once the share is mounted, because pooled
// connections are reused long after the handshake.
func (s *SMBStorage) createConnection() (*smbConnection, error) {
	const handshakeTimeout = 10 * time.Second

	// Connect to SMB server (net.JoinHostPort brackets IPv6 literals)
	addr := net.JoinHostPort(s.config.Host, fmt.Sprintf("%d", s.config.Port))
	conn, err := net.DialTimeout("tcp", addr, handshakeTimeout)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to SMB server: %w", err)
	}

	// Bound the session setup and mount below with the same timeout.
	if err := conn.SetDeadline(time.Now().Add(handshakeTimeout)); err != nil {
		_ = conn.Close()
		return nil, fmt.Errorf("failed to set SMB handshake deadline: %w", err)
	}

	d := &smb2.Dialer{
		Initiator: &smb2.NTLMInitiator{
			User:     s.config.Username,
			Password: s.config.Password,
			Domain:   s.config.Domain,
		},
	}

	// Establish SMB session
	session, err := d.Dial(conn)
	if err != nil {
		_ = conn.Close()
		return nil, fmt.Errorf("failed to establish SMB session: %w", err)
	}

	// Mount share
	share, err := session.Mount(s.config.Share)
	if err != nil {
		_ = session.Logoff()
		_ = conn.Close()
		return nil, fmt.Errorf("failed to mount SMB share: %w", err)
	}

	// Remove the handshake deadline so later I/O on this pooled connection
	// is not cut off.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		_ = share.Umount()
		_ = session.Logoff()
		_ = conn.Close()
		return nil, fmt.Errorf("failed to clear SMB handshake deadline: %w", err)
	}

	return &smbConnection{
		conn:    conn,
		session: session,
		share:   share,
		lastUse: time.Now(),
	}, nil
}
// getConnection takes an idle connection from the pool, or dials a new one
// when the pool is empty. A pooled connection that has been idle for more
// than maxIdle is closed and replaced by a fresh connection.
//
// If connPool has been closed (the receive reports ok == false), an error is
// returned. Otherwise the code would dereference the nil value that such a
// receive produces.
//
// NOTE(review): liveness is judged only by idle time. A connection that
// breaks while pooled (e.g. after a server restart) is still handed out, and
// callers return it to the pool even when their operation fails. Consider
// discarding connections on transport errors.
func (s *SMBStorage) getConnection() (*smbConnection, error) {
	const maxIdle = 5 * time.Minute

	select {
	case conn, ok := <-s.connPool:
		if !ok {
			return nil, fmt.Errorf("SMB storage is closed")
		}
		if time.Since(conn.lastUse) > maxIdle {
			conn.close()
			return s.createConnection()
		}
		conn.lastUse = time.Now()
		return conn, nil
	default:
		// Pool is empty, create new connection
		return s.createConnection()
	}
}
// returnConnection hands conn back to the pool for reuse. If the pool is
// already full, the connection is closed instead. A nil conn is ignored.
//
// lastUse is refreshed here so that getConnection's idle check measures how
// long the connection actually sat unused in the pool. Measuring from the
// last checkout would count a long-running operation as idle time.
func (s *SMBStorage) returnConnection(conn *smbConnection) {
	if conn == nil {
		return
	}
	conn.lastUse = time.Now()
	select {
	case s.connPool <- conn:
		// Successfully returned to pool
	default:
		// Pool is full, close connection
		conn.close()
	}
}
// close tears the connection down in reverse order of construction: it
// unmounts the share, logs off the session, then closes the TCP connection.
// Every step runs even if an earlier one failed; failures are only logged.
func (c *smbConnection) close() {
	warnIfErr := func(err error, msg string) {
		if err != nil {
			log.Warn().Err(err).Msg(msg)
		}
	}

	if c.share != nil {
		warnIfErr(c.share.Umount(), "Failed to unmount SMB share")
	}
	if c.session != nil {
		warnIfErr(c.session.Logoff(), "Failed to logoff SMB session")
	}
	if c.conn != nil {
		warnIfErr(c.conn.Close(), "Failed to close SMB connection")
	}
}
// Get returns the full contents stored under key as an in-memory reader.
//
// The file is read completely before the connection goes back to the pool,
// so the returned reader does not depend on the connection and its Close is
// a no-op. A missing file yields errors.NotFound; other failures are wrapped
// as storage failures. A failure to close the file after reading is only
// logged.
//
// The data is buffered once in a strings.Builder, whose String() does not
// copy. The previous io.ReadAll followed by string(data) held two full copies
// of every artifact in memory.
func (s *SMBStorage) Get(ctx context.Context, key string) (io.ReadCloser, error) {
	conn, err := s.getConnection()
	if err != nil {
		return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get SMB connection")
	}
	defer s.returnConnection(conn)

	path := s.keyToPath(key)
	log.Debug().Str("key", path).Msg("Getting file from SMB")

	file, err := conn.share.Open(path)
	if err != nil {
		if os.IsNotExist(err) {
			return nil, errors.NotFound(fmt.Sprintf("SMB file not found: %s", key))
		}
		return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to open SMB file")
	}

	var contents strings.Builder
	_, err = io.Copy(&contents, file)
	if closeErr := file.Close(); closeErr != nil {
		log.Warn().Err(closeErr).Str("path", path).Msg("Failed to close SMB file after reading")
	}
	if err != nil {
		return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to read SMB file")
	}

	return io.NopCloser(strings.NewReader(contents.String())), nil
}
// Put stores data under key, replacing any existing file.
//
// The payload is buffered in memory first. Its size is needed for the quota
// check, and nothing (not even parent directories) is created on the share
// when the write is going to be rejected. When a quota is set, the size of a
// file being overwritten is subtracted from current usage, because the
// overwrite frees it. If usage cannot be computed, the quota check is
// skipped with a warning.
//
// If writing or closing the file fails, the partially written file is
// removed on a best-effort basis, so a truncated object is not later reported
// as present. ctx and opts are currently unused.
func (s *SMBStorage) Put(ctx context.Context, key string, data io.Reader, opts *storage.PutOptions) error {
	conn, err := s.getConnection()
	if err != nil {
		return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get SMB connection")
	}
	defer s.returnConnection(conn)

	path := s.keyToPath(key)
	log.Debug().Str("key", path).Msg("Putting file to SMB")

	// Buffer the payload: its size is needed for the quota check below.
	buf := new(strings.Builder)
	size, err := io.Copy(buf, data)
	if err != nil {
		return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to read data")
	}

	// Check quota if set
	if s.maxSizeBytes > 0 {
		currentUsage, err := s.calculateUsage(conn)
		if err != nil {
			log.Warn().Err(err).Msg("Failed to calculate current usage, skipping quota check")
		} else {
			// Overwriting frees the old file's bytes; don't count them twice.
			if info, statErr := conn.share.Stat(path); statErr == nil && !info.IsDir() {
				currentUsage -= info.Size()
			}
			if currentUsage+size > s.maxSizeBytes {
				return errors.QuotaExceeded(s.maxSizeBytes)
			}
		}
	}

	// Ensure the parent directory exists. The parent is everything before the
	// last '\'. filepath.Dir cannot be used here: on non-Windows hosts it does
	// not treat '\' as a separator and returned "." for every key, so no
	// directory was ever created.
	if i := strings.LastIndex(path, "\\"); i > 0 {
		if err := s.ensureDir(conn, path[:i]); err != nil {
			return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create SMB directory")
		}
	}

	// Create/overwrite file
	file, err := conn.share.Create(path)
	if err != nil {
		return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create SMB file")
	}

	// Close explicitly (not deferred and ignored) so its error reaches the caller.
	_, writeErr := io.WriteString(file, buf.String())
	closeErr := file.Close()
	if writeErr != nil || closeErr != nil {
		// Best effort: don't leave a truncated file that Exists/Get would see.
		if rmErr := conn.share.Remove(path); rmErr != nil && !os.IsNotExist(rmErr) {
			log.Warn().Err(rmErr).Str("path", path).Msg("Failed to remove partially written SMB file")
		}
		if writeErr != nil {
			return errors.Wrap(writeErr, errors.ErrCodeStorageFailure, "failed to write SMB file")
		}
		return errors.Wrap(closeErr, errors.ErrCodeStorageFailure, "failed to close SMB file")
	}
	return nil
}
// ensureDir makes sure the '\'-separated directory path exists on the share,
// creating missing ancestors first (parent before child).
//
// Empty, ".", "/" and "\" paths are treated as the share root and need no
// work. If path already exists but is not a directory, an error is returned
// instead of silently reporting success. A Mkdir that fails because the
// directory already exists (e.g. it was created concurrently) is not an error.
func (s *SMBStorage) ensureDir(conn *smbConnection, path string) error {
	path = strings.TrimRight(path, "\\")
	if path == "" || path == "." || path == "/" {
		return nil
	}

	// Try to stat the directory
	if info, err := conn.share.Stat(path); err == nil {
		if info.IsDir() {
			return nil // Directory exists
		}
		return fmt.Errorf("SMB path %q exists but is not a directory", path)
	}

	// Create ancestors first. The parent is everything before the last '\'.
	// filepath.Dir cannot be used here: on non-Windows hosts it does not
	// recognise '\' and would stop the recursion immediately.
	if i := strings.LastIndex(path, "\\"); i > 0 {
		if err := s.ensureDir(conn, path[:i]); err != nil {
			return err
		}
	}

	// Create this directory
	if err := conn.share.Mkdir(path, 0755); err != nil && !os.IsExist(err) {
		return err
	}
	return nil
}
// Delete removes the file stored under key. Deleting a key that does not
// exist is not an error; any other removal failure is wrapped as a storage
// failure. ctx is currently unused.
func (s *SMBStorage) Delete(ctx context.Context, key string) error {
	conn, err := s.getConnection()
	if err != nil {
		return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get SMB connection")
	}
	defer s.returnConnection(conn)

	target := s.keyToPath(key)
	log.Debug().Str("key", target).Msg("Deleting file from SMB")

	switch rmErr := conn.share.Remove(target); {
	case rmErr == nil, os.IsNotExist(rmErr):
		return nil
	default:
		return errors.Wrap(rmErr, errors.ErrCodeStorageFailure, "failed to delete SMB file")
	}
}
// Exists reports whether Stat succeeds for key's path (a directory also
// counts). A not-found result is (false, nil); any other Stat failure is
// returned as a wrapped storage error. ctx is currently unused.
func (s *SMBStorage) Exists(ctx context.Context, key string) (bool, error) {
	conn, err := s.getConnection()
	if err != nil {
		return false, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get SMB connection")
	}
	defer s.returnConnection(conn)

	_, statErr := conn.share.Stat(s.keyToPath(key))
	switch {
	case statErr == nil:
		return true, nil
	case os.IsNotExist(statErr):
		return false, nil
	default:
		return false, errors.Wrap(statErr, errors.ErrCodeStorageFailure, "failed to check SMB file existence")
	}
}
// List returns every regular file found under prefix. The prefix is mapped
// through keyToPath and walked as a directory (or a single file); it is not
// matched as a string prefix. Directories are descended recursively. A prefix
// that does not exist yields a nil slice and no error. ctx and opts are
// currently unused.
func (s *SMBStorage) List(ctx context.Context, prefix string, opts *storage.ListOptions) ([]storage.StorageObject, error) {
	conn, err := s.getConnection()
	if err != nil {
		return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get SMB connection")
	}
	defer s.returnConnection(conn)

	root := s.keyToPath(prefix)
	log.Debug().Str("prefix", root).Msg("Listing files in SMB")

	var found []storage.StorageObject
	collect := func(p string, fi os.FileInfo) error {
		if !fi.IsDir() {
			found = append(found, storage.StorageObject{
				Key:      s.pathToKey(p),
				Size:     fi.Size(),
				Modified: fi.ModTime(),
			})
		}
		return nil
	}

	if err := s.walkPath(conn, root, collect); err != nil {
		return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to list SMB files")
	}
	return found, nil
}
// walkPath visits the tree rooted at root on the share.
//
// If root does not exist, nothing is visited and nil is returned. If root is
// a regular file, fn is called once with root itself. Otherwise fn is called
// for every entry beneath root (root itself is not passed to fn), depth-first,
// with each directory passed to fn before its contents. The first error from
// fn or from the share aborts the walk and is returned.
//
// NOTE(review): child paths are built with filepath.Join, which uses the host
// OS separator ('/' on Unix), while the rest of this backend builds
// '\'-separated paths. This presumably relies on go-smb2 accepting '/';
// verify before relying on it.
func (s *SMBStorage) walkPath(conn *smbConnection, root string, fn func(string, os.FileInfo) error) error {
	rootInfo, err := conn.share.Stat(root)
	switch {
	case os.IsNotExist(err):
		return nil // Empty directory
	case err != nil:
		return err
	case !rootInfo.IsDir():
		return fn(root, rootInfo)
	}

	children, err := conn.share.ReadDir(root)
	if err != nil {
		return err
	}

	for _, child := range children {
		childPath := filepath.Join(root, child.Name())
		if err := fn(childPath, child); err != nil {
			return err
		}
		if !child.IsDir() {
			continue
		}
		if err := s.walkPath(conn, childPath, fn); err != nil {
			return err
		}
	}
	return nil
}
// Stat returns the size and modification time of the object stored under
// key. A missing file yields errors.NotFound; any other Stat failure is
// wrapped as a storage failure. ctx is currently unused.
func (s *SMBStorage) Stat(ctx context.Context, key string) (*storage.StorageInfo, error) {
	conn, err := s.getConnection()
	if err != nil {
		return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get SMB connection")
	}
	defer s.returnConnection(conn)

	fi, statErr := conn.share.Stat(s.keyToPath(key))
	switch {
	case os.IsNotExist(statErr):
		return nil, errors.NotFound(fmt.Sprintf("SMB file not found: %s", key))
	case statErr != nil:
		return nil, errors.Wrap(statErr, errors.ErrCodeStorageFailure, "failed to stat SMB file")
	}

	return &storage.StorageInfo{
		Key:      key,
		Size:     fi.Size(),
		Modified: fi.ModTime(),
	}, nil
}
// GetQuota reports the bytes currently stored under the base path (computed
// by walking the whole tree) together with the configured limit
// (MaxSizeBytes; 0 when no quota was set). ctx is currently unused.
//
// A usage-calculation failure is wrapped as a storage failure, like every
// other error this backend returns. It was previously returned unwrapped.
func (s *SMBStorage) GetQuota(ctx context.Context) (*storage.QuotaInfo, error) {
	conn, err := s.getConnection()
	if err != nil {
		return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to get SMB connection")
	}
	defer s.returnConnection(conn)

	usage, err := s.calculateUsage(conn)
	if err != nil {
		return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to calculate SMB usage")
	}

	return &storage.QuotaInfo{
		Used:  usage,
		Limit: s.maxSizeBytes,
	}, nil
}
// Health reports whether the share is usable. It obtains a connection and
// stats the configured base path; a base path that does not exist yet is not
// an error. ctx is currently unused.
//
// When the check fails, the connection is closed instead of being returned
// to the pool. A connection that may be broken (e.g. after a server restart)
// is then not handed to the next operation, and the pool can recover by
// dialling fresh connections.
func (s *SMBStorage) Health(ctx context.Context) error {
	conn, err := s.getConnection()
	if err != nil {
		return errors.Wrap(err, errors.ErrCodeStorageFailure, "SMB health check failed: cannot get connection")
	}

	// Try to stat the base path
	if _, err := conn.share.Stat(s.config.Path); err != nil && !os.IsNotExist(err) {
		conn.close()
		return errors.Wrap(err, errors.ErrCodeStorageFailure, "SMB health check failed")
	}

	s.returnConnection(conn)
	return nil
}
// Close closes every connection currently idle in the pool and returns nil.
//
// The pool channel is drained rather than closed. Closing it had three
// effects:
//   - any later returnConnection (from an operation still in flight)
//     panicked with "send on closed channel";
//   - getConnection dereferenced the nil value received from the closed
//     channel;
//   - a second Close panicked.
//
// Close is now safe to call more than once. Connections that are checked out
// while Close runs are still returned to the pool afterwards and stay open,
// and later operations dial new connections. Callers should stop issuing
// operations before calling Close.
func (s *SMBStorage) Close() error {
	for {
		select {
		case conn := <-s.connPool:
			conn.close()
		default:
			// Pool is empty: every idle connection has been closed.
			log.Info().Msg("SMB storage closed")
			return nil
		}
	}
}
// keyToPath maps a storage key to a '\'-separated path on the share, rooted
// at the configured base path.
//
// Both '/' and '\' in the key act as separators, and the cleaning rules are:
//   - Empty segments are dropped, so a leading "/" or "//" no longer produces
//     a doubled separator.
//   - "." segments are dropped.
//   - ".." removes the previous segment and never climbs above the base
//     path, so a key cannot address files outside it.
//
// An empty key maps to the base path itself, without a trailing separator.
// Ordinary keys such as "npm/pkg/file.tgz" map exactly as before:
// "<base>\npm\pkg\file.tgz".
func (s *SMBStorage) keyToPath(key string) string {
	segments := strings.FieldsFunc(key, func(r rune) bool {
		return r == '/' || r == '\\'
	})

	// Clean in place: len(cleaned) never exceeds the current read index.
	cleaned := segments[:0]
	for _, seg := range segments {
		switch seg {
		case ".":
			// Refers to the current directory: nothing to add.
		case "..":
			if len(cleaned) > 0 {
				cleaned = cleaned[:len(cleaned)-1]
			}
		default:
			cleaned = append(cleaned, seg)
		}
	}

	rel := strings.Join(cleaned, "\\")
	switch {
	case s.config.Path == "":
		return rel
	case rel == "":
		return s.config.Path
	default:
		// Use backslash for SMB paths
		return s.config.Path + "\\" + rel
	}
}
// pathToKey maps a path found on the share back to a storage key. It drops
// the "<base>\" prefix when a base path is configured, then turns every '\'
// into '/'.
func (s *SMBStorage) pathToKey(path string) string {
	rel := path
	if base := s.config.Path; base != "" {
		rel = strings.TrimPrefix(rel, base+"\\")
	}
	return strings.ReplaceAll(rel, "\\", "/")
}
// calculateUsage returns the total size in bytes of all regular files below
// the configured base path, or below the whole share when Path is empty.
//
// It walks the entire tree on every call, so its cost grows with the number
// of stored files. Put calls it on every write when a quota is configured.
//
// The share root is addressed as "". This is the same convention keyToPath
// (empty base path), List("") and Health already use. Previously the root
// was addressed as "\", which filepath.Join turned into child paths such as
// `\/name` on non-Windows hosts.
func (s *SMBStorage) calculateUsage(conn *smbConnection) (int64, error) {
	var total int64
	err := s.walkPath(conn, s.config.Path, func(_ string, info os.FileInfo) error {
		if !info.IsDir() {
			total += info.Size()
		}
		return nil
	})
	if err != nil {
		return 0, fmt.Errorf("failed to calculate usage: %w", err)
	}
	return total, nil
}