mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-13 02:06:24 +00:00
Initial commit
This commit is contained in:
@@ -0,0 +1,346 @@
|
||||
// Package session provides session lifecycle management for claude-mnemonic.
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// MessageType represents the type of pending message.
|
||||
type MessageType int
|
||||
|
||||
const (
|
||||
MessageTypeObservation MessageType = iota
|
||||
MessageTypeSummarize
|
||||
)
|
||||
|
||||
// ObservationData contains data for a tool observation.
|
||||
type ObservationData struct {
|
||||
ToolName string
|
||||
ToolInput interface{}
|
||||
ToolResponse interface{}
|
||||
PromptNumber int
|
||||
CWD string
|
||||
}
|
||||
|
||||
// SummarizeData contains data for a summarize request.
|
||||
type SummarizeData struct {
|
||||
LastUserMessage string
|
||||
LastAssistantMessage string
|
||||
}
|
||||
|
||||
// PendingMessage represents a message queued for SDK processing.
|
||||
type PendingMessage struct {
|
||||
Type MessageType
|
||||
Observation *ObservationData
|
||||
Summarize *SummarizeData
|
||||
}
|
||||
|
||||
// ActiveSession represents an in-memory active session being processed.
|
||||
type ActiveSession struct {
|
||||
SessionDBID int64
|
||||
ClaudeSessionID string
|
||||
SDKSessionID string
|
||||
Project string
|
||||
UserPrompt string
|
||||
LastPromptNumber int
|
||||
StartTime time.Time
|
||||
CumulativeInputTokens int64
|
||||
CumulativeOutputTokens int64
|
||||
|
||||
// Concurrency control
|
||||
pendingMessages []PendingMessage
|
||||
messageMu sync.Mutex
|
||||
notify chan struct{}
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
generatorActive atomic.Bool
|
||||
}
|
||||
|
||||
// Manager manages active session lifecycles.
|
||||
type Manager struct {
|
||||
sessionStore *sqlite.SessionStore
|
||||
sessions map[int64]*ActiveSession
|
||||
mu sync.RWMutex
|
||||
onDeleted func(int64)
|
||||
}
|
||||
|
||||
// NewManager creates a new session manager.
|
||||
func NewManager(sessionStore *sqlite.SessionStore) *Manager {
|
||||
return &Manager{
|
||||
sessionStore: sessionStore,
|
||||
sessions: make(map[int64]*ActiveSession),
|
||||
}
|
||||
}
|
||||
|
||||
// SetOnSessionDeleted sets a callback for when a session is deleted.
|
||||
func (m *Manager) SetOnSessionDeleted(callback func(int64)) {
|
||||
m.onDeleted = callback
|
||||
}
|
||||
|
||||
// InitializeSession initializes a session, creating it if needed.
|
||||
func (m *Manager) InitializeSession(ctx context.Context, sessionDBID int64, userPrompt string, promptNumber int) (*ActiveSession, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Check if already active
|
||||
if session, ok := m.sessions[sessionDBID]; ok {
|
||||
// Update user prompt for continuation
|
||||
if userPrompt != "" {
|
||||
session.UserPrompt = userPrompt
|
||||
session.LastPromptNumber = promptNumber
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// Fetch from database
|
||||
dbSession, err := m.sessionStore.GetSessionByID(ctx, sessionDBID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if dbSession == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Use provided userPrompt or fall back to database
|
||||
prompt := userPrompt
|
||||
if prompt == "" && dbSession.UserPrompt.Valid {
|
||||
prompt = dbSession.UserPrompt.String
|
||||
}
|
||||
|
||||
// Get prompt counter if not provided
|
||||
if promptNumber <= 0 {
|
||||
promptNumber, _ = m.sessionStore.GetPromptCounter(ctx, sessionDBID)
|
||||
}
|
||||
|
||||
// Create session context
|
||||
sessionCtx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
session := &ActiveSession{
|
||||
SessionDBID: sessionDBID,
|
||||
ClaudeSessionID: dbSession.ClaudeSessionID,
|
||||
SDKSessionID: dbSession.SDKSessionID.String,
|
||||
Project: dbSession.Project,
|
||||
UserPrompt: prompt,
|
||||
LastPromptNumber: promptNumber,
|
||||
StartTime: time.Now(),
|
||||
pendingMessages: make([]PendingMessage, 0, 32),
|
||||
notify: make(chan struct{}, 1),
|
||||
ctx: sessionCtx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
m.sessions[sessionDBID] = session
|
||||
|
||||
log.Info().
|
||||
Int64("sessionId", sessionDBID).
|
||||
Str("project", session.Project).
|
||||
Str("claudeSessionId", session.ClaudeSessionID).
|
||||
Msg("Session initialized")
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// QueueObservation queues an observation for SDK processing.
|
||||
func (m *Manager) QueueObservation(ctx context.Context, sessionDBID int64, data ObservationData) error {
|
||||
m.mu.Lock()
|
||||
session, ok := m.sessions[sessionDBID]
|
||||
if !ok {
|
||||
// Auto-initialize from database
|
||||
m.mu.Unlock()
|
||||
var err error
|
||||
session, err = m.InitializeSession(ctx, sessionDBID, "", 0)
|
||||
if err != nil || session == nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
session.messageMu.Lock()
|
||||
session.pendingMessages = append(session.pendingMessages, PendingMessage{
|
||||
Type: MessageTypeObservation,
|
||||
Observation: &data,
|
||||
})
|
||||
queueDepth := len(session.pendingMessages)
|
||||
session.messageMu.Unlock()
|
||||
|
||||
// Non-blocking notification
|
||||
select {
|
||||
case session.notify <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Int64("sessionId", sessionDBID).
|
||||
Str("tool", data.ToolName).
|
||||
Int("queueDepth", queueDepth).
|
||||
Msg("Observation queued")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueueSummarize queues a summarize request for SDK processing.
|
||||
func (m *Manager) QueueSummarize(ctx context.Context, sessionDBID int64, lastUserMessage, lastAssistantMessage string) error {
|
||||
m.mu.Lock()
|
||||
session, ok := m.sessions[sessionDBID]
|
||||
if !ok {
|
||||
// Auto-initialize from database
|
||||
m.mu.Unlock()
|
||||
var err error
|
||||
session, err = m.InitializeSession(ctx, sessionDBID, "", 0)
|
||||
if err != nil || session == nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
session.messageMu.Lock()
|
||||
session.pendingMessages = append(session.pendingMessages, PendingMessage{
|
||||
Type: MessageTypeSummarize,
|
||||
Summarize: &SummarizeData{
|
||||
LastUserMessage: lastUserMessage,
|
||||
LastAssistantMessage: lastAssistantMessage,
|
||||
},
|
||||
})
|
||||
queueDepth := len(session.pendingMessages)
|
||||
session.messageMu.Unlock()
|
||||
|
||||
// Non-blocking notification
|
||||
select {
|
||||
case session.notify <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Int64("sessionId", sessionDBID).
|
||||
Int("queueDepth", queueDepth).
|
||||
Msg("Summarize request queued")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteSession removes a session and cleans up resources.
|
||||
func (m *Manager) DeleteSession(sessionDBID int64) {
|
||||
m.mu.Lock()
|
||||
session, ok := m.sessions[sessionDBID]
|
||||
if !ok {
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
delete(m.sessions, sessionDBID)
|
||||
m.mu.Unlock()
|
||||
|
||||
// Cancel context to stop generator
|
||||
session.cancel()
|
||||
|
||||
duration := time.Since(session.StartTime)
|
||||
log.Info().
|
||||
Int64("sessionId", sessionDBID).
|
||||
Str("project", session.Project).
|
||||
Dur("duration", duration).
|
||||
Msg("Session deleted")
|
||||
|
||||
// Trigger callback
|
||||
if m.onDeleted != nil {
|
||||
m.onDeleted(sessionDBID)
|
||||
}
|
||||
}
|
||||
|
||||
// ShutdownAll shuts down all active sessions.
|
||||
func (m *Manager) ShutdownAll(ctx context.Context) {
|
||||
m.mu.Lock()
|
||||
sessionIDs := make([]int64, 0, len(m.sessions))
|
||||
for id := range m.sessions {
|
||||
sessionIDs = append(sessionIDs, id)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
for _, id := range sessionIDs {
|
||||
m.DeleteSession(id)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Int("count", len(sessionIDs)).
|
||||
Msg("All sessions shut down")
|
||||
}
|
||||
|
||||
// GetActiveSessionCount returns the number of active sessions.
|
||||
func (m *Manager) GetActiveSessionCount() int {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return len(m.sessions)
|
||||
}
|
||||
|
||||
// GetTotalQueueDepth returns the total queue depth across all sessions.
|
||||
func (m *Manager) GetTotalQueueDepth() int {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
total := 0
|
||||
for _, session := range m.sessions {
|
||||
session.messageMu.Lock()
|
||||
total += len(session.pendingMessages)
|
||||
session.messageMu.Unlock()
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// IsAnySessionProcessing returns true if any session is actively processing.
|
||||
func (m *Manager) IsAnySessionProcessing() bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
for _, session := range m.sessions {
|
||||
// Check for pending messages
|
||||
session.messageMu.Lock()
|
||||
hasPending := len(session.pendingMessages) > 0
|
||||
session.messageMu.Unlock()
|
||||
if hasPending {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for active generator
|
||||
if session.generatorActive.Load() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetAllSessions returns a copy of all active sessions.
|
||||
func (m *Manager) GetAllSessions() []*ActiveSession {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
sessions := make([]*ActiveSession, 0, len(m.sessions))
|
||||
for _, session := range m.sessions {
|
||||
sessions = append(sessions, session)
|
||||
}
|
||||
return sessions
|
||||
}
|
||||
|
||||
// DrainMessages drains and returns all pending messages for a session.
|
||||
func (m *Manager) DrainMessages(sessionDBID int64) []PendingMessage {
|
||||
m.mu.RLock()
|
||||
session, ok := m.sessions[sessionDBID]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
session.messageMu.Lock()
|
||||
messages := make([]PendingMessage, len(session.pendingMessages))
|
||||
copy(messages, session.pendingMessages)
|
||||
session.pendingMessages = session.pendingMessages[:0]
|
||||
session.messageMu.Unlock()
|
||||
|
||||
return messages
|
||||
}
|
||||
Reference in New Issue
Block a user