mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-09 23:59:40 +00:00
d04b60517a
* Make things 'betterer' across the board * fix: reorganize struct fields and config parameters for consistency - [x] Reorder Config struct fields alphabetically and by related functionality - [x] Reorganize Observation model fields with archival fields grouped together - [x] Reorder ObservationStore fields to group related members - [x] Reorder Store struct fields with health check caching grouped - [x] Reorganize HealthInfo and PoolMetrics struct field order - [x] Reorder maintenance Service struct fields logically - [x] Reorganize MCP server handler parameter structs alphabetically - [x] Reorder pattern detector candidate tracking fields - [x] Reorganize search Manager struct fields by functionality - [x] Reorder vector Client struct fields with mutex protections grouped - [x] Reorganize handler request/response struct fields - [x] Update handlers_test.go to expect wrapped response format - [x] Reorder middleware TokenAuth and rate limiter fields - [x] Reorganize Service struct fields with grouped functionality - [x] Fix RateLimiter field ordering for clarity - [x] Reorder CircuitBreaker metrics fields * fix(security): improve JSON output safety and path traversal protection - [x] Replace unsafe JSON string formatting with proper json.Marshal in export handler - [x] Remove escapeJSONString helper function in favor of standard JSON marshaling - [x] Add safeResolvePath function to validate paths and prevent directory traversal - [x] Apply path traversal validation in captureFileMtimes operations - [x] Cap result slice capacity in getRecentSearchQueries to prevent DoS via excessive allocation * fix(sdk): improve path traversal protection and allocation safety - [x] Enhance safeResolvePath with stricter validation using filepath.Rel - [x] Reject paths containing ".." 
after cleaning to prevent traversal - [x] Validate absolute paths are within cwd when cwd is specified - [x] Apply safeResolvePath validation to GetFileContent for consistency - [x] Add comprehensive test coverage for path traversal protection - [x] Fix allocation safety in getRecentSearchQueries by using constant capacity
220 lines
6.2 KiB
Go
220 lines
6.2 KiB
Go
// Package gorm provides GORM-based database operations for claude-mnemonic.
|
|
package gorm
|
|
|
|
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"time"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"

	"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
)
|
|
|
|
// SessionStore provides session-related database operations using GORM.
|
|
type SessionStore struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
// NewSessionStore creates a new session store.
|
|
func NewSessionStore(store *Store) *SessionStore {
|
|
return &SessionStore{db: store.DB}
|
|
}
|
|
|
|
// CreateSDKSession creates a new SDK session (idempotent - returns existing ID if exists).
|
|
// This is the KEY to how claude-mnemonic stays unified across hooks.
|
|
func (s *SessionStore) CreateSDKSession(ctx context.Context, claudeSessionID, project, userPrompt string) (int64, error) {
|
|
now := time.Now()
|
|
|
|
session := &SDKSession{
|
|
ClaudeSessionID: claudeSessionID,
|
|
SDKSessionID: func() sql.NullString {
|
|
return sql.NullString{String: claudeSessionID, Valid: true}
|
|
}(),
|
|
Project: project,
|
|
UserPrompt: func() sql.NullString {
|
|
if userPrompt != "" {
|
|
return sql.NullString{String: userPrompt, Valid: true}
|
|
}
|
|
return sql.NullString{Valid: false}
|
|
}(),
|
|
Status: "active",
|
|
StartedAt: now.Format(time.RFC3339),
|
|
StartedAtEpoch: now.UnixMilli(),
|
|
}
|
|
|
|
// CRITICAL: INSERT OR IGNORE makes this idempotent
|
|
// Use OnConflict with DoNothing to achieve INSERT OR IGNORE behavior
|
|
result := s.db.WithContext(ctx).
|
|
Clauses(clause.OnConflict{
|
|
Columns: []clause.Column{{Name: "claude_session_id"}},
|
|
DoNothing: true,
|
|
}).
|
|
Create(session)
|
|
|
|
if result.Error != nil {
|
|
return 0, result.Error
|
|
}
|
|
|
|
// Check if insert happened
|
|
if result.RowsAffected == 0 {
|
|
// Session exists - UPDATE project and user_prompt if we have non-empty values
|
|
if project != "" {
|
|
updates := map[string]interface{}{
|
|
"project": project,
|
|
}
|
|
if userPrompt != "" {
|
|
updates["user_prompt"] = userPrompt
|
|
}
|
|
if err := s.db.WithContext(ctx).
|
|
Model(&SDKSession{}).
|
|
Where("claude_session_id = ?", claudeSessionID).
|
|
Updates(updates).Error; err != nil {
|
|
return 0, fmt.Errorf("failed to update session: %w", err)
|
|
}
|
|
}
|
|
|
|
// Fetch existing session
|
|
var existing SDKSession
|
|
err := s.db.WithContext(ctx).
|
|
Where("claude_session_id = ?", claudeSessionID).
|
|
First(&existing).Error
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return existing.ID, nil
|
|
}
|
|
|
|
return session.ID, nil
|
|
}
|
|
|
|
// GetSessionByID retrieves a session by its database ID.
|
|
func (s *SessionStore) GetSessionByID(ctx context.Context, id int64) (*models.SDKSession, error) {
|
|
var sess SDKSession
|
|
err := s.db.WithContext(ctx).First(&sess, id).Error
|
|
if err == gorm.ErrRecordNotFound {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return toModelSDKSession(&sess), nil
|
|
}
|
|
|
|
// FindAnySDKSession finds any session by Claude session ID (any status).
|
|
func (s *SessionStore) FindAnySDKSession(ctx context.Context, claudeSessionID string) (*models.SDKSession, error) {
|
|
var sess SDKSession
|
|
err := s.db.WithContext(ctx).
|
|
Where("claude_session_id = ?", claudeSessionID).
|
|
First(&sess).Error
|
|
if err == gorm.ErrRecordNotFound {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return toModelSDKSession(&sess), nil
|
|
}
|
|
|
|
// IncrementPromptCounter increments the prompt counter and returns the new value.
|
|
// Uses a single SQL query with RETURNING clause for optimal performance.
|
|
func (s *SessionStore) IncrementPromptCounter(ctx context.Context, id int64) (int, error) {
|
|
// Use raw SQL with RETURNING to get updated value in single query
|
|
// SQLite supports RETURNING since version 3.35.0 (2021-03-12)
|
|
var newCounter int
|
|
err := s.db.WithContext(ctx).Raw(`
|
|
UPDATE sdk_sessions
|
|
SET prompt_counter = COALESCE(prompt_counter, 0) + 1
|
|
WHERE id = ?
|
|
RETURNING prompt_counter
|
|
`, id).Scan(&newCounter).Error
|
|
|
|
if err != nil {
|
|
// Fallback for older SQLite versions without RETURNING support
|
|
if err.Error() == "near \"RETURNING\": syntax error" || newCounter == 0 {
|
|
// Atomic increment
|
|
updateErr := s.db.WithContext(ctx).
|
|
Model(&SDKSession{}).
|
|
Where("id = ?", id).
|
|
Update("prompt_counter", gorm.Expr("COALESCE(prompt_counter, 0) + 1")).Error
|
|
if updateErr != nil {
|
|
return 0, updateErr
|
|
}
|
|
|
|
// Fetch updated value
|
|
var sess SDKSession
|
|
fetchErr := s.db.WithContext(ctx).
|
|
Select("prompt_counter").
|
|
First(&sess, id).Error
|
|
if fetchErr != nil {
|
|
return 0, fetchErr
|
|
}
|
|
return sess.PromptCounter, nil
|
|
}
|
|
return 0, err
|
|
}
|
|
|
|
return newCounter, nil
|
|
}
|
|
|
|
// GetPromptCounter returns the current prompt counter for a session.
|
|
func (s *SessionStore) GetPromptCounter(ctx context.Context, id int64) (int, error) {
|
|
var sess SDKSession
|
|
err := s.db.WithContext(ctx).
|
|
Select("prompt_counter").
|
|
First(&sess, id).Error
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return sess.PromptCounter, nil
|
|
}
|
|
|
|
// GetSessionsToday returns the count of sessions started today.
|
|
func (s *SessionStore) GetSessionsToday(ctx context.Context) (int, error) {
|
|
// Get start of today in milliseconds
|
|
now := time.Now()
|
|
startOfDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
|
startEpoch := startOfDay.UnixMilli()
|
|
|
|
var count int64
|
|
err := s.db.WithContext(ctx).
|
|
Model(&SDKSession{}).
|
|
Where("started_at_epoch >= ?", startEpoch).
|
|
Count(&count).Error
|
|
|
|
return int(count), err
|
|
}
|
|
|
|
// GetAllProjects returns all unique project names.
|
|
func (s *SessionStore) GetAllProjects(ctx context.Context) ([]string, error) {
|
|
var projects []string
|
|
err := s.db.WithContext(ctx).
|
|
Model(&SDKSession{}).
|
|
Distinct("project").
|
|
Where("project IS NOT NULL AND project != ''").
|
|
Order("project ASC").
|
|
Pluck("project", &projects).Error
|
|
|
|
return projects, err
|
|
}
|
|
|
|
// toModelSDKSession converts a GORM SDKSession to pkg/models.SDKSession.
|
|
func toModelSDKSession(sess *SDKSession) *models.SDKSession {
|
|
return &models.SDKSession{
|
|
ID: sess.ID,
|
|
ClaudeSessionID: sess.ClaudeSessionID,
|
|
SDKSessionID: sess.SDKSessionID,
|
|
Project: sess.Project,
|
|
UserPrompt: sess.UserPrompt,
|
|
WorkerPort: sess.WorkerPort,
|
|
PromptCounter: int64(sess.PromptCounter),
|
|
Status: models.SessionStatus(sess.Status),
|
|
StartedAt: sess.StartedAt,
|
|
StartedAtEpoch: sess.StartedAtEpoch,
|
|
CompletedAt: sess.CompletedAt,
|
|
CompletedAtEpoch: sess.CompletedAtEpoch,
|
|
}
|
|
}
|