// Source: gohoarder/pkg/storage/smb/smb.go
// Snapshot: 2026-01-02 18:20:15 +00:00 (580 lines, 14 KiB, Go)

package smb
import (
"bytes"
"context"
"crypto/md5" // #nosec G501 -- MD5 used for file checksums, not cryptographic security
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/hirochachacha/go-smb2"
"github.com/lukaszraczylo/gohoarder/pkg/errors"
"github.com/lukaszraczylo/gohoarder/pkg/metrics"
"github.com/lukaszraczylo/gohoarder/pkg/storage"
"github.com/rs/zerolog/log"
)
// SMBStorage implements storage.StorageBackend for SMB/CIFS shares.
//
// It holds the connection settings captured by New, a channel-based pool of
// mounted SMB connections, and a running byte counter (used) that Put, Delete
// and calculateUsage keep up to date under mu.
type SMBStorage struct {
// host is the "host:port" address dialed for every new connection.
host string
// share is the share name passed to Session.Mount.
share string
// basePath is prepended to every key by keyToPath; New strips leading and
// trailing slashes/backslashes from it.
basePath string
// username and password are the NTLM credentials used by createConnection.
username string
password string
// quota is the byte limit checked by Put; a value <= 0 disables the check.
quota int64
// mu guards used.
mu sync.RWMutex
// used is the number of bytes this backend believes are stored on the share.
used int64
// connPool is a buffered channel acting as the connection pool.
connPool chan *smbConnection
// poolSize is the configured pool capacity (also the capacity of connPool).
poolSize int
}
// smbConnection bundles everything that makes up one pooled connection: the
// TCP connection, the SMB session negotiated over it, and the mounted share.
type smbConnection struct {
// conn is the underlying TCP connection.
conn net.Conn
// session is the authenticated SMB session running over conn.
session *smb2.Session
// share is the mounted share on which all file operations are issued.
share *smb2.Share
// lastUse is set when the connection is created and each time it is taken
// from the pool.
lastUse time.Time
}
// Config holds the settings accepted by New.
// Host and Share are required; New rejects a Config where either is empty.
type Config struct {
Host string // SMB server hostname or IP (required)
Port int // SMB server port (New substitutes 445 when left at 0)
Share string // SMB share name (required)
BasePath string // Base path within the share; surrounding slashes/backslashes are trimmed by New
Username string // SMB username used for NTLM authentication
Password string // SMB password used for NTLM authentication
Domain string // SMB domain (optional). NOTE(review): not read by any code visible in this file - confirm whether it should reach the NTLM initiator
Quota int64 // Quota in bytes (0 = unlimited)
PoolSize int // Connection pool size (New substitutes 5 when left at 0)
}
// New creates a new SMB storage backend.
//
// It validates cfg, applies defaults (port 445, pool size 5) for values that
// are zero or negative, eagerly opens cfg.PoolSize connections to the share,
// and finally walks the base path to seed the usage counter. A failure of
// that initial walk is only logged; the backend is still returned.
//
// It returns an invalid-config error when Host or Share is empty, and a
// wrapped storage-failure error when any pool connection cannot be created
// (connections opened so far are closed first).
func New(ctx context.Context, cfg Config) (*SMBStorage, error) {
	if cfg.Host == "" {
		return nil, errors.New(errors.ErrCodeInvalidConfig, "SMB host is required")
	}
	if cfg.Share == "" {
		return nil, errors.New(errors.ErrCodeInvalidConfig, "SMB share is required")
	}

	// Treat non-positive values as "unset". A negative pool size would make
	// the channel allocation below panic, and a negative port is never valid.
	if cfg.Port <= 0 {
		cfg.Port = 445 // Default SMB port
	}
	if cfg.PoolSize <= 0 {
		cfg.PoolSize = 5 // Default pool size
	}

	smbStorage := &SMBStorage{
		// JoinHostPort brackets IPv6 literals ("::1" -> "[::1]:445"); plain
		// "host:port" concatenation yields an address that cannot be dialed.
		// Brackets the caller already supplied are stripped first so they are
		// not doubled.
		host:     net.JoinHostPort(strings.Trim(cfg.Host, "[]"), fmt.Sprint(cfg.Port)),
		share:    cfg.Share,
		basePath: strings.Trim(cfg.BasePath, "/\\"),
		username: cfg.Username,
		password: cfg.Password,
		quota:    cfg.Quota,
		connPool: make(chan *smbConnection, cfg.PoolSize),
		poolSize: cfg.PoolSize,
	}

	// Initialize connection pool
	for i := 0; i < cfg.PoolSize; i++ {
		conn, err := smbStorage.createConnection(ctx)
		if err != nil {
			// Clean up any created connections. Closing the channel is safe
			// here because the backend has not been handed to anyone yet.
			close(smbStorage.connPool)
			for c := range smbStorage.connPool {
				c.close()
			}
			return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create SMB connection pool")
		}
		smbStorage.connPool <- conn
	}

	// Calculate initial usage; a failure leaves the counter at zero.
	if err := smbStorage.calculateUsage(ctx); err != nil {
		log.Warn().Err(err).Msg("Failed to calculate initial SMB storage usage")
	}

	return smbStorage, nil
}
// createConnection dials the SMB server, authenticates with NTLM using the
// stored credentials and mounts the configured share.
//
// ctx bounds both the TCP dial and the SMB session negotiation, so a
// cancelled or expired context aborts connection setup instead of blocking
// on an unresponsive server. On any failure everything opened so far is
// released and the underlying error is returned unwrapped.
func (s *SMBStorage) createConnection(ctx context.Context) (*smbConnection, error) {
	// Earlier versions ignored ctx entirely, so tolerate a nil one rather
	// than letting DialContext panic.
	if ctx == nil {
		ctx = context.Background()
	}

	var tcpDialer net.Dialer
	conn, err := tcpDialer.DialContext(ctx, "tcp", s.host)
	if err != nil {
		return nil, err
	}

	dialer := &smb2.Dialer{
		Initiator: &smb2.NTLMInitiator{
			User:     s.username,
			Password: s.password,
		},
	}

	session, err := dialer.DialContext(ctx, conn)
	if err != nil {
		conn.Close() // #nosec G104 -- Cleanup, error not critical
		return nil, err
	}

	share, err := session.Mount(s.share)
	if err != nil {
		_ = session.Logoff() // #nosec G104 -- SMB cleanup
		conn.Close()         // #nosec G104 -- Cleanup, error not critical
		return nil, err
	}

	return &smbConnection{
		conn:    conn,
		session: session,
		share:   share,
		lastUse: time.Now(),
	}, nil
}
// getConnection takes a connection out of the pool, waiting up to 30 seconds
// for one to become available.
//
// It returns ctx.Err() if ctx is done first, and a storage-failure error on
// timeout or when the pool channel has been closed. Callers must hand the
// connection back with returnConnection.
func (s *SMBStorage) getConnection(ctx context.Context) (*smbConnection, error) {
	// An explicit timer (instead of time.After) is stopped on return, so a
	// call that succeeds immediately does not leave a 30s timer pending.
	timeout := time.NewTimer(30 * time.Second)
	defer timeout.Stop()

	select {
	case conn, ok := <-s.connPool:
		// A receive from a closed channel yields a nil connection; report
		// that as an error instead of dereferencing it.
		if !ok || conn == nil {
			return nil, errors.New(errors.ErrCodeStorageFailure, "SMB connection pool is closed")
		}
		conn.lastUse = time.Now()
		return conn, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-timeout.C:
		return nil, errors.New(errors.ErrCodeStorageFailure, "timeout waiting for SMB connection")
	}
}
// returnConnection hands conn back to the pool without blocking. If the pool
// has no free slot the connection is closed instead of being queued.
func (s *SMBStorage) returnConnection(conn *smbConnection) {
	select {
	case s.connPool <- conn:
		// Accepted by the pool; nothing more to do.
		return
	default:
	}

	// No room left in the pool: release the connection's resources.
	conn.close()
}
// close tears the connection down from the top layer to the bottom: unmount
// the share, log off the session, then close the TCP connection. Layers that
// were never set are skipped and all teardown errors are ignored.
func (c *smbConnection) close() {
	if sh := c.share; sh != nil {
		_ = sh.Umount() // #nosec G104 -- SMB cleanup
	}
	if sess := c.session; sess != nil {
		_ = sess.Logoff() // #nosec G104 -- SMB cleanup
	}
	if tcp := c.conn; tcp != nil {
		tcp.Close() // #nosec G104 -- Cleanup, error not critical
	}
}
// Get retrieves the file stored under key.
//
// The whole file is read into memory so that the pooled connection can be
// returned before the caller starts consuming the data; the returned
// ReadCloser is backed by that in-memory copy. A missing file yields a
// not-found error, any other failure a wrapped storage-failure error.
func (s *SMBStorage) Get(ctx context.Context, key string) (io.ReadCloser, error) {
	pooled, err := s.getConnection(ctx)
	if err != nil {
		return nil, err
	}

	remote, openErr := pooled.share.Open(s.keyToPath(key))
	if openErr != nil {
		s.returnConnection(pooled)
		if !os.IsNotExist(openErr) {
			metrics.RecordStorageOperation("smb", "get", "error")
			return nil, errors.Wrap(openErr, errors.ErrCodeStorageFailure, "failed to open SMB file")
		}
		metrics.RecordStorageOperation("smb", "get", "not_found")
		return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key))
	}

	// Slurp the contents, then release the remote handle and the connection
	// before looking at the read result.
	contents, readErr := io.ReadAll(remote)
	remote.Close() // #nosec G104 -- Cleanup, error not critical
	s.returnConnection(pooled)

	if readErr != nil {
		metrics.RecordStorageOperation("smb", "get", "error")
		return nil, errors.Wrap(readErr, errors.ErrCodeStorageFailure, "failed to read SMB file")
	}

	metrics.RecordStorageOperation("smb", "get", "success")
	return io.NopCloser(bytes.NewReader(contents)), nil
}
// Put stores the contents of data under key on the SMB share.
//
// The payload is first buffered in memory while its MD5 and SHA-256 digests
// are computed, then checked against the quota and against any checksums in
// opts. Only after those checks pass is a pooled connection taken, so a slow
// reader or a rejected upload neither ties up a connection nor leaves empty
// directories behind on the share. The data is written to "<path>.tmp" and
// renamed into place.
//
// It returns a quota-exceeded error, a checksum-mismatch error, or a wrapped
// storage-failure error; nil on success.
func (s *SMBStorage) Put(ctx context.Context, key string, data io.Reader, opts *storage.PutOptions) error {
	// Read data into buffer to calculate checksums and size
	var buf bytes.Buffer
	md5Hash := md5.New() // #nosec G401 -- MD5 used for file integrity check, not cryptographic security
	sha256Hash := sha256.New()
	written, err := io.Copy(io.MultiWriter(&buf, md5Hash, sha256Hash), data)
	if err != nil {
		metrics.RecordStorageOperation("smb", "put", "error")
		return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to read data")
	}

	// Check quota
	if s.quota > 0 {
		s.mu.RLock()
		used := s.used
		s.mu.RUnlock()
		if used+written > s.quota {
			metrics.RecordStorageOperation("smb", "put", "quota_exceeded")
			return errors.QuotaExceeded(s.quota)
		}
	}

	// Verify checksums if provided
	if opts != nil {
		md5Sum := hex.EncodeToString(md5Hash.Sum(nil))
		sha256Sum := hex.EncodeToString(sha256Hash.Sum(nil))
		if opts.ChecksumMD5 != "" && opts.ChecksumMD5 != md5Sum {
			metrics.RecordStorageOperation("smb", "put", "checksum_error")
			return errors.New(errors.ErrCodeChecksumMismatch, "MD5 checksum mismatch")
		}
		if opts.ChecksumSHA256 != "" && opts.ChecksumSHA256 != sha256Sum {
			metrics.RecordStorageOperation("smb", "put", "checksum_error")
			return errors.New(errors.ErrCodeChecksumMismatch, "SHA256 checksum mismatch")
		}
	}

	conn, err := s.getConnection(ctx)
	if err != nil {
		return err
	}
	defer s.returnConnection(conn)

	path := s.keyToPath(key)

	// Create directory structure
	if err := conn.share.MkdirAll(filepath.Dir(path), 0755); err != nil {
		metrics.RecordStorageOperation("smb", "put", "error")
		return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create SMB directory")
	}

	// Create temp file for atomic write
	tempPath := path + ".tmp"
	file, err := conn.share.Create(tempPath)
	if err != nil {
		metrics.RecordStorageOperation("smb", "put", "error")
		return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to create SMB temp file")
	}

	// Write data. A failure reported by Close counts as a write failure: the
	// data may not have reached the server, so the temp file must not be
	// renamed into place.
	_, err = io.Copy(file, &buf)
	if closeErr := file.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		_ = conn.share.Remove(tempPath) // #nosec G104 -- SMB cleanup
		metrics.RecordStorageOperation("smb", "put", "error")
		return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to write SMB file")
	}

	// Atomic rename
	if err := conn.share.Rename(tempPath, path); err != nil {
		_ = conn.share.Remove(tempPath) // #nosec G104 -- SMB cleanup
		metrics.RecordStorageOperation("smb", "put", "error")
		return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to rename SMB temp file")
	}

	// Update usage.
	// NOTE(review): if the rename replaces an existing file for this key, the
	// old size is never subtracted and the counter over-reports - confirm the
	// rename semantics of the SMB library/server for existing targets.
	s.mu.Lock()
	s.used += written
	currentUsed := s.used
	s.mu.Unlock()

	metrics.RecordStorageOperation("smb", "put", "success")
	metrics.UpdateCacheSize("smb", currentUsed)
	return nil
}
// Delete removes the file stored under key and subtracts its size from the
// usage counter.
//
// It returns a not-found error when the file does not exist and a wrapped
// storage-failure error when the stat or the removal fails.
func (s *SMBStorage) Delete(ctx context.Context, key string) error {
	conn, err := s.getConnection(ctx)
	if err != nil {
		return err
	}
	defer s.returnConnection(conn)

	path := s.keyToPath(key)

	// Get size before deletion
	info, err := conn.share.Stat(path)
	if err != nil {
		if os.IsNotExist(err) {
			metrics.RecordStorageOperation("smb", "delete", "not_found")
			return errors.NotFound(fmt.Sprintf("file not found: %s", key))
		}
		metrics.RecordStorageOperation("smb", "delete", "error")
		return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to stat SMB file")
	}
	size := info.Size()

	if err := conn.share.Remove(path); err != nil {
		metrics.RecordStorageOperation("smb", "delete", "error")
		return errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to delete SMB file")
	}

	// Update usage. The counter can be lower than what is really on the share
	// (for example when the initial usage scan failed and left it at zero), so
	// clamp at zero rather than publishing a negative size and letting the
	// quota check in Put admit more data than the quota allows.
	s.mu.Lock()
	s.used -= size
	if s.used < 0 {
		s.used = 0
	}
	currentUsed := s.used
	s.mu.Unlock()

	metrics.RecordStorageOperation("smb", "delete", "success")
	metrics.UpdateCacheSize("smb", currentUsed)
	return nil
}
// Exists reports whether a file is stored under key. A "does not exist"
// result from the share is returned as (false, nil); any other stat failure
// is wrapped as a storage-failure error.
func (s *SMBStorage) Exists(ctx context.Context, key string) (bool, error) {
	pooled, err := s.getConnection(ctx)
	if err != nil {
		return false, err
	}
	defer s.returnConnection(pooled)

	_, statErr := pooled.share.Stat(s.keyToPath(key))
	switch {
	case statErr == nil:
		return true, nil
	case os.IsNotExist(statErr):
		return false, nil
	default:
		return false, errors.Wrap(statErr, errors.ErrCodeStorageFailure, "failed to check SMB file existence")
	}
}
// List returns the files found under prefix on the SMB share.
//
// prefix is mapped to a path with keyToPath and walked recursively; a prefix
// that does not exist yields an empty result rather than an error. When opts
// is non-nil, Offset and MaxResults select a window of the collected
// entries: a negative Offset is treated as 0, an Offset at or past the end
// yields an empty slice, and MaxResults <= 0 means "no limit".
func (s *SMBStorage) List(ctx context.Context, prefix string, opts *storage.ListOptions) ([]storage.StorageObject, error) {
	conn, err := s.getConnection(ctx)
	if err != nil {
		return nil, err
	}
	defer s.returnConnection(conn)

	searchPath := s.keyToPath(prefix)
	var objects []storage.StorageObject
	err = s.walkPath(conn.share, searchPath, func(path string, info os.FileInfo) error {
		if info.IsDir() {
			return nil
		}
		objects = append(objects, storage.StorageObject{
			Key:      s.pathToKey(path),
			Size:     info.Size(),
			Modified: info.ModTime(),
		})
		return nil
	})
	if err != nil && !os.IsNotExist(err) {
		return nil, errors.Wrap(err, errors.ErrCodeStorageFailure, "failed to list SMB files")
	}

	// Apply pagination if requested
	if opts != nil {
		start := opts.Offset
		if start < 0 {
			// A negative offset would make the slice expression below panic.
			start = 0
		}
		if start >= len(objects) {
			return []storage.StorageObject{}, nil
		}
		end := len(objects)
		// Compare against the remaining length so start+MaxResults cannot
		// overflow.
		if opts.MaxResults > 0 && opts.MaxResults < end-start {
			end = start + opts.MaxResults
		}
		objects = objects[start:end]
	}

	return objects, nil
}
// Stat returns the size and modification time of the file stored under key.
// A missing file yields a not-found error; any other stat failure is wrapped
// as a storage-failure error.
func (s *SMBStorage) Stat(ctx context.Context, key string) (*storage.StorageInfo, error) {
	pooled, err := s.getConnection(ctx)
	if err != nil {
		return nil, err
	}
	defer s.returnConnection(pooled)

	fi, statErr := pooled.share.Stat(s.keyToPath(key))
	switch {
	case statErr == nil:
		return &storage.StorageInfo{
			Key:      key,
			Size:     fi.Size(),
			Modified: fi.ModTime(),
		}, nil
	case os.IsNotExist(statErr):
		return nil, errors.NotFound(fmt.Sprintf("file not found: %s", key))
	default:
		return nil, errors.Wrap(statErr, errors.ErrCodeStorageFailure, "failed to stat SMB file")
	}
}
// GetQuota reports the tracked usage, the configured limit and the space
// left between them. Available never goes below zero; note that with a quota
// of 0 it is therefore always reported as 0. The error is always nil.
func (s *SMBStorage) GetQuota(ctx context.Context) (*storage.QuotaInfo, error) {
	s.mu.RLock()
	used := s.used
	s.mu.RUnlock()

	report := &storage.QuotaInfo{
		Used:  used,
		Limit: s.quota,
	}
	// Available keeps its zero value unless there is headroom left.
	if headroom := s.quota - used; headroom > 0 {
		report.Available = headroom
	}
	return report, nil
}
// Health verifies that a pooled connection can be obtained and that the
// share answers a stat of the base path. A base path that does not exist is
// still treated as healthy; any other failure is wrapped as a
// storage-failure error.
func (s *SMBStorage) Health(ctx context.Context) error {
	pooled, err := s.getConnection(ctx)
	if err != nil {
		return errors.Wrap(err, errors.ErrCodeStorageFailure, "SMB health check failed - connection error")
	}
	defer s.returnConnection(pooled)

	// Probe the base path; only errors other than "not exist" count.
	if _, statErr := pooled.share.Stat(s.keyToPath("")); statErr != nil && !os.IsNotExist(statErr) {
		return errors.Wrap(statErr, errors.ErrCodeStorageFailure, "SMB health check failed")
	}
	return nil
}
// Close releases every connection currently sitting idle in the pool. It
// always returns nil.
//
// The pool channel itself is deliberately left open: returnConnection sends
// on it, and closing a channel from the receiving side would make an
// operation that finishes after Close panic on that send, and a second Close
// panic on the double close. Draining instead makes Close safe to call
// repeatedly. Connections that are checked out while Close runs are not
// touched here; when they are returned they go back into the pool and are
// only closed by a later Close call.
func (s *SMBStorage) Close() error {
	for {
		select {
		case conn, ok := <-s.connPool:
			if !ok {
				// Channel was closed elsewhere and is fully drained.
				return nil
			}
			if conn != nil {
				conn.close()
			}
		default:
			// Pool is empty.
			return nil
		}
	}
}
// keyToPath converts a storage key into a path on the share, rooted at
// basePath when one is configured.
//
// The key is split on both "/" and "\" and resolved component by component:
// empty and "." components are dropped, and ".." removes the previous
// component but can never climb above the root. Handling both separators
// here matters because filepath.Clean only understands the host OS
// separator; on a non-Windows host a key such as `a\..\..\..\x` would
// otherwise keep its ".." components and escape basePath once the SMB layer
// interprets the backslashes.
//
// An empty (or fully cancelled-out) key maps to basePath itself, or to "."
// when no basePath is configured.
func (s *SMBStorage) keyToPath(key string) string {
	components := strings.FieldsFunc(key, func(r rune) bool {
		return r == '/' || r == '\\'
	})

	resolved := make([]string, 0, len(components))
	for _, c := range components {
		switch c {
		case ".":
			// Current directory: contributes nothing.
		case "..":
			// Step up one level, but never past the root.
			if n := len(resolved); n > 0 {
				resolved = resolved[:n-1]
			}
		default:
			resolved = append(resolved, c)
		}
	}

	// filepath.Clean converts to the host separator and turns "" into ".".
	key = filepath.Clean(strings.Join(resolved, "/"))
	if s.basePath != "" {
		return filepath.Join(s.basePath, key)
	}
	return key
}
// pathToKey turns a path on the share back into a storage key. When a
// basePath is configured it is stripped from the front of path, followed by
// at most one leading "/" and then at most one leading "\". The result is
// returned with forward slashes (per filepath.ToSlash).
func (s *SMBStorage) pathToKey(path string) string {
	if s.basePath == "" {
		return filepath.ToSlash(path)
	}

	rel := strings.TrimPrefix(path, s.basePath)
	for _, sep := range []string{"/", "\\"} {
		rel = strings.TrimPrefix(rel, sep)
	}
	return filepath.ToSlash(rel)
}
// walkPath visits every non-directory entry at or below path, depth-first in
// the order ReadDir returns entries, and calls fn with the entry's path and
// FileInfo. If path itself is not a directory, fn is called once for it.
// The first error from Stat, ReadDir or fn stops the walk and is returned.
func (s *SMBStorage) walkPath(share *smb2.Share, path string, fn func(string, os.FileInfo) error) error {
	root, err := share.Stat(path)
	if err != nil {
		return err
	}
	if !root.IsDir() {
		return fn(path, root)
	}

	children, err := share.ReadDir(path)
	if err != nil {
		return err
	}

	for i := range children {
		child := children[i]
		childPath := filepath.Join(path, child.Name())

		var visitErr error
		if child.IsDir() {
			// Descend; the recursive call stats the directory again itself.
			visitErr = s.walkPath(share, childPath, fn)
		} else {
			visitErr = fn(childPath, child)
		}
		if visitErr != nil {
			return visitErr
		}
	}
	return nil
}
// calculateUsage walks everything below the base path, sums the sizes of all
// files and stores the total as the current usage (also publishing it as the
// cache-size metric). A base path that does not exist counts as zero usage;
// any other walk error is returned and leaves the counter unchanged.
func (s *SMBStorage) calculateUsage(ctx context.Context) error {
	pooled, err := s.getConnection(ctx)
	if err != nil {
		return err
	}
	defer s.returnConnection(pooled)

	var sum int64
	addSize := func(_ string, fi os.FileInfo) error {
		if fi.IsDir() {
			return nil
		}
		sum += fi.Size()
		return nil
	}

	walkErr := s.walkPath(pooled.share, s.keyToPath(""), addSize)
	if walkErr != nil && !os.IsNotExist(walkErr) {
		return walkErr
	}

	s.mu.Lock()
	s.used = sum
	s.mu.Unlock()

	metrics.UpdateCacheSize("smb", sum)
	return nil
}