Files
claude-mnemonic/internal/db/gorm/session_store.go
T
lukaszraczylo d04b60517a Make things 'betterer' across the board (#23)
* Make things 'betterer' across the board

* fix: reorganize struct fields and config parameters for consistency

- [x] Reorder Config struct fields alphabetically and by related functionality
- [x] Reorganize Observation model fields with archival fields grouped together
- [x] Reorder ObservationStore fields to group related members
- [x] Reorder Store struct fields with health check caching grouped
- [x] Reorganize HealthInfo and PoolMetrics struct field order
- [x] Reorder maintenance Service struct fields logically
- [x] Reorganize MCP server handler parameter structs alphabetically
- [x] Reorder pattern detector candidate tracking fields
- [x] Reorganize search Manager struct fields by functionality
- [x] Reorder vector Client struct fields with mutex protections grouped
- [x] Reorganize handler request/response struct fields
- [x] Update handlers_test.go to expect wrapped response format
- [x] Reorder middleware TokenAuth and rate limiter fields
- [x] Reorganize Service struct fields with grouped functionality
- [x] Fix RateLimiter field ordering for clarity
- [x] Reorder CircuitBreaker metrics fields

* fix(security): improve JSON output safety and path traversal protection

- [x] Replace unsafe JSON string formatting with proper json.Marshal in export handler
- [x] Remove escapeJSONString helper function in favor of standard JSON marshaling
- [x] Add safeResolvePath function to validate paths and prevent directory traversal
- [x] Apply path traversal validation in captureFileMtimes operations
- [x] Cap result slice capacity in getRecentSearchQueries to prevent DoS via excessive allocation

* fix(sdk): improve path traversal protection and allocation safety

- [x] Enhance safeResolvePath with stricter validation using filepath.Rel
- [x] Reject paths containing ".." after cleaning to prevent traversal
- [x] Validate absolute paths are within cwd when cwd is specified
- [x] Apply safeResolvePath validation to GetFileContent for consistency
- [x] Add comprehensive test coverage for path traversal protection
- [x] Fix allocation safety in getRecentSearchQueries by using constant capacity
2026-01-11 01:51:20 +00:00

220 lines
6.2 KiB
Go

// Package gorm provides GORM-based database operations for claude-mnemonic.
package gorm
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)
// SessionStore bundles the GORM-backed queries that operate on SDK sessions.
// Build one with NewSessionStore rather than using the zero value, whose db
// handle is nil.
type SessionStore struct {
	db *gorm.DB // handle shared with the owning Store
}
// NewSessionStore returns a SessionStore that reuses the GORM handle held by
// store. store must be non-nil: its DB field is read immediately.
func NewSessionStore(store *Store) *SessionStore {
	handle := store.DB
	return &SessionStore{db: handle}
}
// CreateSDKSession inserts a new "active" SDK session keyed by claudeSessionID
// and returns its database ID. The call is idempotent: if a session with that
// claudeSessionID already exists, the existing row's ID is returned instead.
// This is the KEY to how claude-mnemonic stays unified across hooks.
//
// When the session already exists and project is non-empty, the stored project
// is overwritten with project, and user_prompt is overwritten as well when
// userPrompt is non-empty. An empty project leaves the existing row untouched.
// The SDK session ID is always set to claudeSessionID, and an empty userPrompt
// is stored as NULL.
func (s *SessionStore) CreateSDKSession(ctx context.Context, claudeSessionID, project, userPrompt string) (int64, error) {
	startedAt := time.Now()

	prompt := sql.NullString{}
	if userPrompt != "" {
		prompt = sql.NullString{String: userPrompt, Valid: true}
	}

	candidate := &SDKSession{
		ClaudeSessionID: claudeSessionID,
		SDKSessionID:    sql.NullString{String: claudeSessionID, Valid: true},
		Project:         project,
		UserPrompt:      prompt,
		Status:          "active",
		StartedAt:       startedAt.Format(time.RFC3339),
		StartedAtEpoch:  startedAt.UnixMilli(),
	}

	// CRITICAL: ON CONFLICT (claude_session_id) DO NOTHING is the equivalent of
	// INSERT OR IGNORE. It is what makes repeated calls idempotent.
	insert := s.db.WithContext(ctx).
		Clauses(clause.OnConflict{
			Columns:   []clause.Column{{Name: "claude_session_id"}},
			DoNothing: true,
		}).
		Create(candidate)
	if insert.Error != nil {
		return 0, insert.Error
	}

	// A fresh row was written, so Create has filled in its ID.
	if insert.RowsAffected != 0 {
		return candidate.ID, nil
	}

	// The row already existed. Refresh project (and prompt) if we were given one.
	if project != "" {
		changes := map[string]any{"project": project}
		if userPrompt != "" {
			changes["user_prompt"] = userPrompt
		}
		err := s.db.WithContext(ctx).
			Model(&SDKSession{}).
			Where("claude_session_id = ?", claudeSessionID).
			Updates(changes).Error
		if err != nil {
			return 0, fmt.Errorf("failed to update session: %w", err)
		}
	}

	var existing SDKSession
	if err := s.db.WithContext(ctx).
		Where("claude_session_id = ?", claudeSessionID).
		First(&existing).Error; err != nil {
		return 0, err
	}
	return existing.ID, nil
}
// GetSessionByID retrieves a session by its database ID.
//
// It returns (nil, nil) when no session with that ID exists. Any other query
// failure is returned unchanged.
func (s *SessionStore) GetSessionByID(ctx context.Context, id int64) (*models.SDKSession, error) {
	var sess SDKSession
	err := s.db.WithContext(ctx).First(&sess, id).Error
	// errors.Is (not ==) also matches ErrRecordNotFound when it arrives wrapped,
	// as the GORM error-handling docs recommend.
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return toModelSDKSession(&sess), nil
}
// FindAnySDKSession finds a session by Claude session ID, regardless of its
// status.
//
// It returns (nil, nil) when no session matches claudeSessionID. Any other
// query failure is returned unchanged.
func (s *SessionStore) FindAnySDKSession(ctx context.Context, claudeSessionID string) (*models.SDKSession, error) {
	var sess SDKSession
	err := s.db.WithContext(ctx).
		Where("claude_session_id = ?", claudeSessionID).
		First(&sess).Error
	// errors.Is (not ==) also matches ErrRecordNotFound when it arrives wrapped,
	// as the GORM error-handling docs recommend.
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return toModelSDKSession(&sess), nil
}
// IncrementPromptCounter increments the prompt counter of the session with the
// given database ID and returns the new value. A NULL counter is treated as 0.
//
// Normally this is a single UPDATE ... RETURNING statement. SQLite supports
// RETURNING since version 3.35.0 (2021-03-12). Only when the database rejects
// the RETURNING syntax does it fall back to incrementPromptCounterWithoutReturning.
//
// A missing session yields gorm.ErrRecordNotFound on both paths.
func (s *SessionStore) IncrementPromptCounter(ctx context.Context, id int64) (int, error) {
	var newCounter int
	result := s.db.WithContext(ctx).Raw(`
		UPDATE sdk_sessions
		SET prompt_counter = COALESCE(prompt_counter, 0) + 1
		WHERE id = ?
		RETURNING prompt_counter
	`, id).Scan(&newCounter)
	if err := result.Error; err != nil {
		// Fall back only when RETURNING itself is unsupported. The SQLite parser
		// reports `near "RETURNING": syntax error`, and some drivers add a prefix
		// or suffix, hence Contains. Any other error is returned as-is: retrying
		// after a failure that occurred once the UPDATE had already executed
		// would increment the counter twice.
		if strings.Contains(err.Error(), `near "RETURNING"`) {
			return s.incrementPromptCounterWithoutReturning(ctx, id)
		}
		return 0, err
	}
	// RETURNING produces no row when no session has this ID. Report that,
	// consistent with the fallback path, instead of a misleading counter of 0.
	if result.RowsAffected == 0 {
		return 0, gorm.ErrRecordNotFound
	}
	return newCounter, nil
}

// incrementPromptCounterWithoutReturning is the fallback for databases without
// RETURNING support. It runs the increment and the read-back in one
// transaction. Under SQLite's single-writer locking, no other increment can land
// between the UPDATE and the SELECT, so the returned value is the one written.
func (s *SessionStore) incrementPromptCounterWithoutReturning(ctx context.Context, id int64) (int, error) {
	var sess SDKSession
	err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		if err := tx.Model(&SDKSession{}).
			Where("id = ?", id).
			Update("prompt_counter", gorm.Expr("COALESCE(prompt_counter, 0) + 1")).Error; err != nil {
			return err
		}
		return tx.Select("prompt_counter").First(&sess, id).Error
	})
	if err != nil {
		return 0, err
	}
	return sess.PromptCounter, nil
}
// GetPromptCounter reports the stored prompt_counter of the session with the
// given database ID. Only that column is selected. Any error from the lookup,
// including a missing row, is returned with a zero counter.
func (s *SessionStore) GetPromptCounter(ctx context.Context, id int64) (int, error) {
	var row SDKSession
	if err := s.db.WithContext(ctx).Select("prompt_counter").First(&row, id).Error; err != nil {
		return 0, err
	}
	return row.PromptCounter, nil
}
// GetSessionsToday counts the sessions whose started_at_epoch (Unix
// milliseconds) is at or after midnight of the current day in the local time
// zone.
func (s *SessionStore) GetSessionsToday(ctx context.Context) (int, error) {
	now := time.Now()
	year, month, day := now.Date()
	midnightMs := time.Date(year, month, day, 0, 0, 0, 0, now.Location()).UnixMilli()

	var total int64
	query := s.db.WithContext(ctx).
		Model(&SDKSession{}).
		Where("started_at_epoch >= ?", midnightMs)
	err := query.Count(&total).Error
	return int(total), err
}
// GetAllProjects lists each distinct project name recorded on a session,
// skipping NULL and empty values, in ascending order.
func (s *SessionStore) GetAllProjects(ctx context.Context) ([]string, error) {
	query := s.db.WithContext(ctx).
		Model(&SDKSession{}).
		Distinct("project").
		Where("project IS NOT NULL AND project != ''").
		Order("project ASC")

	var names []string
	err := query.Pluck("project", &names).Error
	return names, err
}
// toModelSDKSession maps a GORM SDKSession row onto the public
// models.SDKSession. PromptCounter is widened to int64, and Status is converted
// to models.SessionStatus. Every other field is copied unchanged. sess must be
// non-nil.
func toModelSDKSession(sess *SDKSession) *models.SDKSession {
	out := &models.SDKSession{}

	out.ID = sess.ID
	out.ClaudeSessionID = sess.ClaudeSessionID
	out.SDKSessionID = sess.SDKSessionID
	out.Project = sess.Project
	out.UserPrompt = sess.UserPrompt
	out.WorkerPort = sess.WorkerPort
	out.PromptCounter = int64(sess.PromptCounter)
	out.Status = models.SessionStatus(sess.Status)

	out.StartedAt = sess.StartedAt
	out.StartedAtEpoch = sess.StartedAtEpoch
	out.CompletedAt = sess.CompletedAt
	out.CompletedAtEpoch = sess.CompletedAtEpoch

	return out
}