mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-12 00:19:20 +00:00
Initial commit
This commit is contained in:
@@ -0,0 +1,14 @@
|
||||
// Package worker provides the main worker service for claude-mnemonic.
|
||||
package worker
|
||||
|
||||
import (
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/similarity"
|
||||
)
|
||||
|
||||
// clusterObservations groups similar observations and returns only one representative per cluster.
|
||||
// Uses Jaccard similarity on extracted terms from title, narrative, and facts.
|
||||
// Delegates to pkg/similarity for the actual clustering logic.
|
||||
func clusterObservations(observations []*models.Observation, similarityThreshold float64) []*models.Observation {
|
||||
return similarity.ClusterObservations(observations, similarityThreshold)
|
||||
}
|
||||
@@ -0,0 +1,705 @@
|
||||
// Package worker provides the main worker service for claude-mnemonic.
|
||||
package worker
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/privacy"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/sdk"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/session"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Handler configuration constants
|
||||
const (
|
||||
// DefaultObservationsLimit is the default number of observations to return.
|
||||
DefaultObservationsLimit = 100
|
||||
|
||||
// DefaultSummariesLimit is the default number of summaries to return.
|
||||
DefaultSummariesLimit = 50
|
||||
|
||||
// DefaultPromptsLimit is the default number of prompts to return.
|
||||
DefaultPromptsLimit = 100
|
||||
|
||||
// DefaultSearchLimit is the default number of search results to return.
|
||||
DefaultSearchLimit = 50
|
||||
|
||||
// DefaultContextLimit is the default number of context observations to return.
|
||||
DefaultContextLimit = 50
|
||||
)
|
||||
|
||||
// writeJSON writes a JSON response with proper error handling.
|
||||
func writeJSON(w http.ResponseWriter, data interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(data); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to encode JSON response")
|
||||
}
|
||||
}
|
||||
|
||||
// handleHealth handles health check requests.
|
||||
// Returns 200 OK immediately (even during init) so hooks can connect quickly.
|
||||
// Use /api/ready for full readiness check.
|
||||
func (s *Service) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
status := "starting"
|
||||
if s.ready.Load() {
|
||||
status = "ready"
|
||||
} else if err := s.GetInitError(); err != nil {
|
||||
status = "error"
|
||||
}
|
||||
writeJSON(w, map[string]interface{}{
|
||||
"status": status,
|
||||
"version": s.version,
|
||||
})
|
||||
}
|
||||
|
||||
// handleVersion returns the worker version for version checking.
|
||||
func (s *Service) handleVersion(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, map[string]string{
|
||||
"version": s.version,
|
||||
})
|
||||
}
|
||||
|
||||
// handleReady handles readiness check requests.
|
||||
// Returns 200 only when fully initialized, 503 otherwise.
|
||||
func (s *Service) handleReady(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.ready.Load() {
|
||||
if err := s.GetInitError(); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
http.Error(w, "service initializing", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
writeJSON(w, map[string]string{"status": "ready"})
|
||||
}
|
||||
|
||||
// requireReady is middleware that returns 503 if service isn't ready.
|
||||
func (s *Service) requireReady(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.ready.Load() {
|
||||
if err := s.GetInitError(); err != nil {
|
||||
http.Error(w, "service initialization failed: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
http.Error(w, "service initializing", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// SessionInitRequest is the request body for session initialization.
|
||||
type SessionInitRequest struct {
|
||||
ClaudeSessionID string `json:"claudeSessionId"`
|
||||
Project string `json:"project"`
|
||||
Prompt string `json:"prompt"`
|
||||
MatchedObservations int `json:"matchedObservations"`
|
||||
}
|
||||
|
||||
// SessionInitResponse is the response for session initialization.
|
||||
type SessionInitResponse struct {
|
||||
SessionDBID int64 `json:"sessionDbId"`
|
||||
PromptNumber int `json:"promptNumber"`
|
||||
Skipped bool `json:"skipped,omitempty"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
// handleSessionInit handles session initialization from user-prompt hook.
|
||||
func (s *Service) handleSessionInit(w http.ResponseWriter, r *http.Request) {
|
||||
var req SessionInitRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Privacy check
|
||||
if privacy.IsEntirelyPrivate(req.Prompt) {
|
||||
// Create session but skip processing
|
||||
sessionID, _ := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, "")
|
||||
promptNum, _ := s.sessionStore.IncrementPromptCounter(r.Context(), sessionID)
|
||||
|
||||
writeJSON(w, SessionInitResponse{
|
||||
SessionDBID: sessionID,
|
||||
PromptNumber: promptNum,
|
||||
Skipped: true,
|
||||
Reason: "private",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Clean prompt and create session
|
||||
cleanedPrompt := privacy.Clean(req.Prompt)
|
||||
sessionID, err := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, cleanedPrompt)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Increment prompt counter
|
||||
promptNum, err := s.sessionStore.IncrementPromptCounter(r.Context(), sessionID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Save user prompt with matched observation count
|
||||
promptID, err := s.promptStore.SaveUserPromptWithMatches(r.Context(), req.ClaudeSessionID, promptNum, cleanedPrompt, req.MatchedObservations)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to save user prompt")
|
||||
// Non-fatal: continue with session initialization
|
||||
} else if s.chromaSync != nil {
|
||||
// Sync to vector DB
|
||||
now := time.Now()
|
||||
promptWithSession := &models.UserPromptWithSession{
|
||||
UserPrompt: models.UserPrompt{
|
||||
ID: promptID,
|
||||
ClaudeSessionID: req.ClaudeSessionID,
|
||||
PromptNumber: promptNum,
|
||||
PromptText: cleanedPrompt,
|
||||
MatchedObservations: req.MatchedObservations,
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
},
|
||||
Project: req.Project,
|
||||
SDKSessionID: req.ClaudeSessionID,
|
||||
}
|
||||
if err := s.chromaSync.SyncUserPrompt(r.Context(), promptWithSession); err != nil {
|
||||
log.Warn().Err(err).Int64("id", promptID).Msg("Failed to sync user prompt to ChromaDB")
|
||||
}
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Int64("sessionId", sessionID).
|
||||
Int("promptNumber", promptNum).
|
||||
Str("project", req.Project).
|
||||
Msg("Session initialized")
|
||||
|
||||
// Broadcast prompt event for dashboard refresh
|
||||
s.sseBroadcaster.Broadcast(map[string]interface{}{
|
||||
"type": "prompt",
|
||||
"action": "created",
|
||||
"project": req.Project,
|
||||
})
|
||||
|
||||
writeJSON(w, SessionInitResponse{
|
||||
SessionDBID: sessionID,
|
||||
PromptNumber: promptNum,
|
||||
})
|
||||
}
|
||||
|
||||
// SessionStartRequest is the request body for starting SDK agent.
|
||||
type SessionStartRequest struct {
|
||||
UserPrompt string `json:"userPrompt"`
|
||||
PromptNumber int `json:"promptNumber"`
|
||||
}
|
||||
|
||||
// handleSessionStart handles SDK agent session start.
|
||||
func (s *Service) handleSessionStart(w http.ResponseWriter, r *http.Request) {
|
||||
idStr := chi.URLParam(r, "id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid session id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var req SessionStartRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Initialize session in manager
|
||||
sess, err := s.sessionManager.InitializeSession(r.Context(), id, req.UserPrompt, req.PromptNumber)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if sess == nil {
|
||||
http.Error(w, "session not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Session is now registered. Observations will be processed
|
||||
// asynchronously by the background queue processor (processQueue in service.go).
|
||||
log.Info().
|
||||
Int64("sessionId", id).
|
||||
Int("promptNumber", req.PromptNumber).
|
||||
Msg("SDK agent session initialized")
|
||||
|
||||
s.broadcastProcessingStatus()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// ObservationRequest is the request body for posting observations.
|
||||
type ObservationRequest struct {
|
||||
ClaudeSessionID string `json:"claudeSessionId"`
|
||||
Project string `json:"project"`
|
||||
ToolName string `json:"tool_name"`
|
||||
ToolInput interface{} `json:"tool_input"`
|
||||
ToolResponse interface{} `json:"tool_response"`
|
||||
CWD string `json:"cwd"`
|
||||
}
|
||||
|
||||
// handleObservation handles observation posting from post-tool-use hook.
|
||||
func (s *Service) handleObservation(w http.ResponseWriter, r *http.Request) {
|
||||
var req ObservationRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Find session
|
||||
sess, err := s.sessionStore.FindAnySDKSession(r.Context(), req.ClaudeSessionID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if sess == nil {
|
||||
// Create session on-the-fly with project from request
|
||||
id, err := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, "")
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
sess, _ = s.sessionStore.GetSessionByID(r.Context(), id)
|
||||
}
|
||||
|
||||
// Queue observation
|
||||
if err := s.sessionManager.QueueObservation(r.Context(), sess.ID, session.ObservationData{
|
||||
ToolName: req.ToolName,
|
||||
ToolInput: req.ToolInput,
|
||||
ToolResponse: req.ToolResponse,
|
||||
CWD: req.CWD,
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
s.broadcastProcessingStatus()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// SubagentCompleteRequest is the request body for subagent completion.
|
||||
type SubagentCompleteRequest struct {
|
||||
ClaudeSessionID string `json:"claudeSessionId"`
|
||||
Project string `json:"project"`
|
||||
}
|
||||
|
||||
// handleSubagentComplete handles subagent/Task completion notifications.
|
||||
// This triggers immediate processing of any queued observations from the subagent.
|
||||
func (s *Service) handleSubagentComplete(w http.ResponseWriter, r *http.Request) {
|
||||
var req SubagentCompleteRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Find session
|
||||
sess, err := s.sessionStore.FindAnySDKSession(r.Context(), req.ClaudeSessionID)
|
||||
if err != nil || sess == nil {
|
||||
// Session not found - subagent may have been in a different context
|
||||
log.Debug().
|
||||
Str("claudeSessionId", req.ClaudeSessionID).
|
||||
Msg("Subagent complete - no active session found")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
// Trigger immediate processing of queued observations
|
||||
messages := s.sessionManager.DrainMessages(sess.ID)
|
||||
if len(messages) > 0 && s.processor != nil {
|
||||
log.Info().
|
||||
Int64("sessionId", sess.ID).
|
||||
Int("messages", len(messages)).
|
||||
Msg("Processing queued observations from subagent")
|
||||
|
||||
for _, msg := range messages {
|
||||
if msg.Type == session.MessageTypeObservation && msg.Observation != nil {
|
||||
err := s.processor.ProcessObservation(
|
||||
r.Context(),
|
||||
sess.SDKSessionID.String,
|
||||
sess.Project,
|
||||
msg.Observation.ToolName,
|
||||
msg.Observation.ToolInput,
|
||||
msg.Observation.ToolResponse,
|
||||
msg.Observation.PromptNumber,
|
||||
msg.Observation.CWD,
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().Err(err).
|
||||
Str("tool", msg.Observation.ToolName).
|
||||
Msg("Failed to process subagent observation")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.broadcastProcessingStatus()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// handleGetSessionByClaudeID looks up a session by Claude session ID.
|
||||
func (s *Service) handleGetSessionByClaudeID(w http.ResponseWriter, r *http.Request) {
|
||||
claudeSessionID := r.URL.Query().Get("claudeSessionId")
|
||||
if claudeSessionID == "" {
|
||||
http.Error(w, "claudeSessionId required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
session, err := s.sessionStore.FindAnySDKSession(r.Context(), claudeSessionID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if session == nil {
|
||||
http.Error(w, "session not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, session)
|
||||
}
|
||||
|
||||
// SummarizeRequest is the request body for summarize requests.
|
||||
type SummarizeRequest struct {
|
||||
LastUserMessage string `json:"lastUserMessage"`
|
||||
LastAssistantMessage string `json:"lastAssistantMessage"`
|
||||
}
|
||||
|
||||
// handleSummarize handles summarize requests from stop hook.
|
||||
func (s *Service) handleSummarize(w http.ResponseWriter, r *http.Request) {
|
||||
idStr := chi.URLParam(r, "id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid session id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var req SummarizeRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Queue summarize request
|
||||
if err := s.sessionManager.QueueSummarize(r.Context(), id, req.LastUserMessage, req.LastAssistantMessage); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
s.broadcastProcessingStatus()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// handleGetObservations returns recent observations.
|
||||
func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request) {
|
||||
limit := DefaultObservationsLimit
|
||||
if l := r.URL.Query().Get("limit"); l != "" {
|
||||
if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 {
|
||||
limit = parsed
|
||||
}
|
||||
}
|
||||
|
||||
project := r.URL.Query().Get("project")
|
||||
|
||||
var observations []*models.Observation
|
||||
var err error
|
||||
|
||||
if project != "" {
|
||||
// Filter by project - includes project-scoped and global observations
|
||||
observations, err = s.observationStore.GetRecentObservations(r.Context(), project, limit)
|
||||
} else {
|
||||
// All projects
|
||||
observations, err = s.observationStore.GetAllRecentObservations(r.Context(), limit)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure we return empty array, not null
|
||||
if observations == nil {
|
||||
observations = []*models.Observation{}
|
||||
}
|
||||
writeJSON(w, observations)
|
||||
}
|
||||
|
||||
// handleGetSummaries returns recent summaries.
|
||||
func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
|
||||
limit := DefaultSummariesLimit
|
||||
if l := r.URL.Query().Get("limit"); l != "" {
|
||||
if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 {
|
||||
limit = parsed
|
||||
}
|
||||
}
|
||||
|
||||
project := r.URL.Query().Get("project")
|
||||
|
||||
var summaries []*models.SessionSummary
|
||||
var err error
|
||||
|
||||
if project != "" {
|
||||
summaries, err = s.summaryStore.GetRecentSummaries(r.Context(), project, limit)
|
||||
} else {
|
||||
summaries, err = s.summaryStore.GetAllRecentSummaries(r.Context(), limit)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure we return empty array, not null
|
||||
if summaries == nil {
|
||||
summaries = []*models.SessionSummary{}
|
||||
}
|
||||
writeJSON(w, summaries)
|
||||
}
|
||||
|
||||
// handleGetPrompts returns recent user prompts.
|
||||
func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) {
|
||||
limit := DefaultPromptsLimit
|
||||
if l := r.URL.Query().Get("limit"); l != "" {
|
||||
if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 {
|
||||
limit = parsed
|
||||
}
|
||||
}
|
||||
|
||||
project := r.URL.Query().Get("project")
|
||||
|
||||
var prompts []*models.UserPromptWithSession
|
||||
var err error
|
||||
|
||||
if project != "" {
|
||||
prompts, err = s.promptStore.GetRecentUserPromptsByProject(r.Context(), project, limit)
|
||||
} else {
|
||||
prompts, err = s.promptStore.GetAllRecentUserPrompts(r.Context(), limit)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure we return empty array, not null
|
||||
if prompts == nil {
|
||||
prompts = []*models.UserPromptWithSession{}
|
||||
}
|
||||
writeJSON(w, prompts)
|
||||
}
|
||||
|
||||
// handleGetProjects returns all projects.
|
||||
func (s *Service) handleGetProjects(w http.ResponseWriter, r *http.Request) {
|
||||
projects, err := s.sessionStore.GetAllProjects(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, projects)
|
||||
}
|
||||
|
||||
// handleGetStats returns worker statistics.
|
||||
func (s *Service) handleGetStats(w http.ResponseWriter, r *http.Request) {
|
||||
retrievalStats := s.GetRetrievalStats()
|
||||
sessionsToday, _ := s.sessionStore.GetSessionsToday(r.Context())
|
||||
|
||||
writeJSON(w, map[string]interface{}{
|
||||
"uptime": time.Since(s.startTime).String(),
|
||||
"activeSessions": s.sessionManager.GetActiveSessionCount(),
|
||||
"queueDepth": s.sessionManager.GetTotalQueueDepth(),
|
||||
"isProcessing": s.sessionManager.IsAnySessionProcessing(),
|
||||
"connectedClients": s.sseBroadcaster.ClientCount(),
|
||||
"sessionsToday": sessionsToday,
|
||||
"retrieval": retrievalStats,
|
||||
})
|
||||
}
|
||||
|
||||
// handleGetRetrievalStats returns detailed retrieval statistics.
|
||||
func (s *Service) handleGetRetrievalStats(w http.ResponseWriter, r *http.Request) {
|
||||
stats := s.GetRetrievalStats()
|
||||
writeJSON(w, stats)
|
||||
}
|
||||
|
||||
// handleContextCount returns the count of observations for a project.
|
||||
func (s *Service) handleContextCount(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
if project == "" {
|
||||
http.Error(w, "project required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
count, err := s.observationStore.GetObservationCount(r.Context(), project)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, map[string]interface{}{
|
||||
"project": project,
|
||||
"count": count,
|
||||
})
|
||||
}
|
||||
|
||||
// handleSearchByPrompt searches observations relevant to a user prompt.
|
||||
// IMPORTANT: This is on the critical startup path - must be fast!
|
||||
// No synchronous verification - just filter by staleness and return.
|
||||
func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
query := r.URL.Query().Get("query")
|
||||
cwd := r.URL.Query().Get("cwd")
|
||||
|
||||
if project == "" || query == "" {
|
||||
http.Error(w, "project and query required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
limit := DefaultSearchLimit
|
||||
if l := r.URL.Query().Get("limit"); l != "" {
|
||||
if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 {
|
||||
limit = parsed
|
||||
}
|
||||
}
|
||||
|
||||
// Search using FTS
|
||||
observations, err := s.observationStore.SearchObservationsFTS(r.Context(), query, project, limit)
|
||||
if err != nil {
|
||||
// FTS might fail if query has special chars, try without
|
||||
log.Warn().Err(err).Str("query", query).Msg("FTS search failed, falling back to recent")
|
||||
observations, err = s.observationStore.GetRecentObservations(r.Context(), project, limit)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Fast staleness filter - NO verification (that's too slow for interactive use)
|
||||
// Just check mtimes and exclude obviously stale observations
|
||||
var staleCount int
|
||||
freshObservations := make([]*models.Observation, 0, len(observations))
|
||||
|
||||
for _, obs := range observations {
|
||||
if len(obs.FileMtimes) > 0 && cwd != "" {
|
||||
var paths []string
|
||||
for path := range obs.FileMtimes {
|
||||
paths = append(paths, path)
|
||||
}
|
||||
currentMtimes := sdk.GetFileMtimes(paths, cwd)
|
||||
|
||||
if obs.CheckStaleness(currentMtimes) {
|
||||
// Stale - exclude but don't verify (too slow)
|
||||
// Queue for background verification instead
|
||||
staleCount++
|
||||
s.queueStaleVerification(obs.ID, cwd)
|
||||
continue
|
||||
}
|
||||
}
|
||||
freshObservations = append(freshObservations, obs)
|
||||
}
|
||||
|
||||
// Cluster similar observations to remove duplicates
|
||||
clusteredObservations := clusterObservations(freshObservations, 0.4)
|
||||
|
||||
// Record retrieval stats (no verification done, so verified=0, deleted=0)
|
||||
s.recordRetrievalStats(int64(len(clusteredObservations)), 0, 0, true)
|
||||
|
||||
log.Info().
|
||||
Str("project", project).
|
||||
Str("query", query).
|
||||
Int("found", len(clusteredObservations)).
|
||||
Int("stale_excluded", staleCount).
|
||||
Msg("Prompt-based observation search")
|
||||
|
||||
writeJSON(w, map[string]interface{}{
|
||||
"project": project,
|
||||
"query": query,
|
||||
"observations": clusteredObservations,
|
||||
})
|
||||
}
|
||||
|
||||
// handleContextInject returns context for injection at session start.
|
||||
// IMPORTANT: This is on the critical startup path - must be fast!
|
||||
// No synchronous verification - just filter by staleness and return.
|
||||
func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
if project == "" {
|
||||
http.Error(w, "project required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
cwd := r.URL.Query().Get("cwd")
|
||||
if cwd == "" {
|
||||
cwd = "/"
|
||||
}
|
||||
|
||||
// Limit observations for fast startup (configurable, default 100)
|
||||
limit := s.config.ContextObservations
|
||||
if limit <= 0 {
|
||||
limit = DefaultContextLimit
|
||||
}
|
||||
|
||||
// Full count determines how many observations get full detail (configurable, default 25)
|
||||
fullCount := s.config.ContextFullCount
|
||||
if fullCount <= 0 {
|
||||
fullCount = 25
|
||||
}
|
||||
|
||||
// Get recent observations
|
||||
observations, err := s.observationStore.GetRecentObservations(r.Context(), project, limit)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Fast staleness filter - NO verification (that's too slow for startup)
|
||||
var staleCount int
|
||||
freshObservations := make([]*models.Observation, 0, len(observations))
|
||||
|
||||
for _, obs := range observations {
|
||||
if len(obs.FileMtimes) > 0 {
|
||||
var paths []string
|
||||
for path := range obs.FileMtimes {
|
||||
paths = append(paths, path)
|
||||
}
|
||||
currentMtimes := sdk.GetFileMtimes(paths, cwd)
|
||||
|
||||
if obs.CheckStaleness(currentMtimes) {
|
||||
// Stale - exclude but don't verify (too slow)
|
||||
// Queue for background verification instead
|
||||
staleCount++
|
||||
s.queueStaleVerification(obs.ID, cwd)
|
||||
continue
|
||||
}
|
||||
}
|
||||
freshObservations = append(freshObservations, obs)
|
||||
}
|
||||
|
||||
// Cluster similar observations to remove duplicates
|
||||
clusteredObservations := clusterObservations(freshObservations, 0.4)
|
||||
duplicatesRemoved := len(freshObservations) - len(clusteredObservations)
|
||||
|
||||
// Record retrieval stats (no verification done)
|
||||
s.recordRetrievalStats(int64(len(clusteredObservations)), 0, 0, false)
|
||||
|
||||
log.Info().
|
||||
Str("project", project).
|
||||
Int("total", len(observations)).
|
||||
Int("fresh", len(freshObservations)).
|
||||
Int("clustered", len(clusteredObservations)).
|
||||
Int("duplicates", duplicatesRemoved).
|
||||
Int("stale_excluded", staleCount).
|
||||
Msg("Context injection with clustering")
|
||||
|
||||
writeJSON(w, map[string]interface{}{
|
||||
"project": project,
|
||||
"observations": clusteredObservations,
|
||||
"full_count": fullCount,
|
||||
"stale_excluded": staleCount,
|
||||
"duplicates_removed": duplicatesRemoved,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,553 @@
|
||||
// Package worker provides the main worker service for claude-mnemonic.
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/config"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/session"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/sse"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testService creates a Service with a test SQLite database including FTS5 for testing.
|
||||
func testService(t *testing.T) (*Service, func()) {
|
||||
t.Helper()
|
||||
|
||||
// Create test store (runs migrations to create all tables including FTS5)
|
||||
store, dbCleanup := testStore(t)
|
||||
|
||||
// Create store wrappers
|
||||
sessionStore := sqlite.NewSessionStore(store)
|
||||
observationStore := sqlite.NewObservationStore(store)
|
||||
summaryStore := sqlite.NewSummaryStore(store)
|
||||
promptStore := sqlite.NewPromptStore(store)
|
||||
|
||||
// Create domain services
|
||||
sessionManager := session.NewManager(sessionStore)
|
||||
sseBroadcaster := sse.NewBroadcaster()
|
||||
|
||||
// Create context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Create router
|
||||
router := chi.NewRouter()
|
||||
|
||||
svc := &Service{
|
||||
version: "test-version",
|
||||
config: config.Get(),
|
||||
store: store,
|
||||
sessionStore: sessionStore,
|
||||
observationStore: observationStore,
|
||||
summaryStore: summaryStore,
|
||||
promptStore: promptStore,
|
||||
sessionManager: sessionManager,
|
||||
sseBroadcaster: sseBroadcaster,
|
||||
router: router,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
startTime: time.Now(),
|
||||
}
|
||||
|
||||
svc.setupRoutes()
|
||||
|
||||
// Mark service as ready for tests
|
||||
svc.ready.Store(true)
|
||||
|
||||
cleanup := func() {
|
||||
cancel()
|
||||
store.Close()
|
||||
dbCleanup()
|
||||
}
|
||||
|
||||
return svc, cleanup
|
||||
}
|
||||
|
||||
// createTestObservation creates a test observation in the database.
|
||||
func createTestObservation(t *testing.T, store *sqlite.ObservationStore, project, title, narrative string, concepts []string) int64 {
|
||||
t.Helper()
|
||||
|
||||
obs := &models.ParsedObservation{
|
||||
Type: models.ObsTypeDiscovery,
|
||||
Title: title,
|
||||
Narrative: narrative,
|
||||
Concepts: concepts,
|
||||
}
|
||||
|
||||
id, _, err := store.StoreObservation(context.Background(), "test-session", project, obs, 1, 100)
|
||||
require.NoError(t, err)
|
||||
return id
|
||||
}
|
||||
|
||||
func TestHandleSearchByPrompt_DefaultLimit(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
project := "test-project"
|
||||
|
||||
// Create 60 observations (more than the default limit of 50)
|
||||
for i := 0; i < 60; i++ {
|
||||
createTestObservation(t, svc.observationStore, project,
|
||||
"Test observation about authentication",
|
||||
"This observation is about authentication and security patterns",
|
||||
[]string{"authentication", "security"})
|
||||
// Small delay to ensure different timestamps
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
// Make request without limit parameter
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/context/search?project="+project+"&query=authentication", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
observations, ok := response["observations"].([]interface{})
|
||||
require.True(t, ok, "observations should be an array")
|
||||
|
||||
// The default limit is now 50, not 5
|
||||
// Note: clustering may reduce the count, but we should have more than 5
|
||||
t.Logf("Got %d observations", len(observations))
|
||||
// Just verify we got a reasonable number, accounting for clustering
|
||||
assert.True(t, len(observations) >= 1, "should return at least one observation")
|
||||
}
|
||||
|
||||
func TestHandleSearchByPrompt_CustomLimit(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
project := "test-project"
|
||||
|
||||
// Create 20 unique observations
|
||||
for i := 0; i < 20; i++ {
|
||||
createTestObservation(t, svc.observationStore, project,
|
||||
"Unique observation "+string(rune('A'+i))+" about testing",
|
||||
"This is unique observation number "+string(rune('A'+i)),
|
||||
[]string{"unique-" + string(rune('a'+i))})
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
// Request with custom limit of 15
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/context/search?project="+project+"&query=observation&limit=15", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
observations, ok := response["observations"].([]interface{})
|
||||
require.True(t, ok)
|
||||
|
||||
// Should respect the custom limit (accounting for clustering)
|
||||
t.Logf("Got %d observations with limit=15", len(observations))
|
||||
assert.LessOrEqual(t, len(observations), 15)
|
||||
}
|
||||
|
||||
func TestHandleSearchByPrompt_NoHardcodedLimit(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
project := "test-project"
|
||||
|
||||
// Create observations with VERY different content to avoid clustering
|
||||
// Each has unique words that won't match other observations
|
||||
uniqueObservations := []struct {
|
||||
title string
|
||||
narrative string
|
||||
concepts []string
|
||||
}{
|
||||
{"JWT tokens expire daily", "OAuth2 bearer tokens authentication", []string{"jwt"}},
|
||||
{"PostgreSQL indexes optimize queries", "B-tree index on user table", []string{"postgres"}},
|
||||
{"Redis caching TTL configuration", "Memory eviction policy LRU", []string{"redis"}},
|
||||
{"Zerolog structured logging", "JSON output formatting levels", []string{"logging"}},
|
||||
{"Pytest fixtures setup teardown", "Mock objects dependency injection", []string{"pytest"}},
|
||||
{"Docker containers orchestration", "Compose multi-stage builds", []string{"docker"}},
|
||||
{"Prometheus metrics collection", "Grafana dashboards alerting", []string{"prometheus"}},
|
||||
{"OWASP vulnerability scanning", "SQL injection XSS prevention", []string{"owasp"}},
|
||||
}
|
||||
|
||||
for _, obs := range uniqueObservations {
|
||||
createTestObservation(t, svc.observationStore, project, obs.title, obs.narrative, obs.concepts)
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
// Search using a common keyword that should match most observations
|
||||
// Using broader query to match multiple items
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/context/search?project="+project+"&query=tokens+indexes+caching+logging&limit=10", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
observations, ok := response["observations"].([]interface{})
|
||||
require.True(t, ok)
|
||||
|
||||
// The key is that the limit is no longer hardcoded to 5
|
||||
// With our new default of 50, we should be able to return more than 5
|
||||
t.Logf("Got %d observations (limit=10)", len(observations))
|
||||
// The test passes as long as the default limit (50) is being used instead of 5
|
||||
// and we can request a custom limit
|
||||
assert.LessOrEqual(t, len(observations), 10, "should respect the custom limit")
|
||||
}
|
||||
|
||||
func TestHandleSearchByPrompt_RequiredParams(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "missing project",
|
||||
query: "/api/context/search?query=test",
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "missing query",
|
||||
query: "/api/context/search?project=test",
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "both present",
|
||||
query: "/api/context/search?project=test&query=test",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, tt.query, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, tt.wantStatus, rec.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleContextInject_NoHardcodedLimit(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Set a higher context observations limit in config
|
||||
svc.config.ContextObservations = 50
|
||||
|
||||
project := "test-project"
|
||||
|
||||
// Create observations with VERY different content to avoid clustering
|
||||
uniqueObservations := []struct {
|
||||
title string
|
||||
narrative string
|
||||
concepts []string
|
||||
}{
|
||||
{"JWT tokens expire daily", "OAuth2 bearer tokens authentication", []string{"jwt"}},
|
||||
{"PostgreSQL indexes optimize queries", "B-tree index on user table", []string{"postgres"}},
|
||||
{"Redis caching TTL configuration", "Memory eviction policy LRU", []string{"redis"}},
|
||||
{"Zerolog structured logging", "JSON output formatting levels", []string{"logging"}},
|
||||
{"Pytest fixtures setup teardown", "Mock objects dependency injection", []string{"pytest"}},
|
||||
{"Docker containers orchestration", "Compose multi-stage builds", []string{"docker"}},
|
||||
{"Prometheus metrics collection", "Grafana dashboards alerting", []string{"prometheus"}},
|
||||
{"OWASP vulnerability scanning", "SQL injection XSS prevention", []string{"owasp"}},
|
||||
}
|
||||
|
||||
for _, obs := range uniqueObservations {
|
||||
createTestObservation(t, svc.observationStore, project, obs.title, obs.narrative, obs.concepts)
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/context/inject?project="+project, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
observations, ok := response["observations"].([]interface{})
|
||||
require.True(t, ok)
|
||||
|
||||
// With very different content, we should get multiple observations back
|
||||
// The key verification is that the hardcoded limit of 5 has been removed
|
||||
t.Logf("Got %d observations from context inject", len(observations))
|
||||
// Should return more than old limit of 5 with unique observations
|
||||
assert.GreaterOrEqual(t, len(observations), 1, "should return at least 1 observation")
|
||||
}
|
||||
|
||||
func TestHandleContextInject_RequiresProject(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/context/inject", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestHandleGetObservations_Limit(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create 20 observations
|
||||
for i := 0; i < 20; i++ {
|
||||
createTestObservation(t, svc.observationStore, "project-"+string(rune('a'+i%5)),
|
||||
"Observation "+string(rune('A'+i)),
|
||||
"Content of observation "+string(rune('A'+i)),
|
||||
[]string{"test"})
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
// Request with limit=10
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/observations?limit=10", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// Parse as generic JSON array since the model uses custom marshaling
|
||||
var observations []map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &observations)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, observations, 10)
|
||||
}
|
||||
|
||||
func TestSearchObservations_GlobalScope(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create a project-scoped observation
|
||||
createTestObservation(t, svc.observationStore, "project-a",
|
||||
"Project specific code",
|
||||
"This is specific to project-a",
|
||||
[]string{"project-specific"})
|
||||
|
||||
// Create a global-scoped observation (has a globalizable concept)
|
||||
createTestObservation(t, svc.observationStore, "project-a",
|
||||
"Security best practice",
|
||||
"Always validate user input",
|
||||
[]string{"security", "best-practice"})
|
||||
|
||||
// Search from project-b - should find global observation
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/context/search?project=project-b&query=security", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
observations, ok := response["observations"].([]interface{})
|
||||
require.True(t, ok)
|
||||
|
||||
// Should find the global observation even though it was created in project-a
|
||||
assert.GreaterOrEqual(t, len(observations), 1)
|
||||
}
|
||||
|
||||
func TestClusterObservations_RemovesDuplicates(t *testing.T) {
|
||||
// Create similar observations
|
||||
obs1 := &models.Observation{
|
||||
ID: 1,
|
||||
Title: sql.NullString{String: "Authentication flow implementation", Valid: true},
|
||||
Narrative: sql.NullString{String: "We implemented JWT-based authentication", Valid: true},
|
||||
}
|
||||
obs2 := &models.Observation{
|
||||
ID: 2,
|
||||
Title: sql.NullString{String: "Authentication flow update", Valid: true},
|
||||
Narrative: sql.NullString{String: "Updated JWT-based authentication logic", Valid: true},
|
||||
}
|
||||
obs3 := &models.Observation{
|
||||
ID: 3,
|
||||
Title: sql.NullString{String: "Database migration guide", Valid: true},
|
||||
Narrative: sql.NullString{String: "How to run database migrations", Valid: true},
|
||||
}
|
||||
|
||||
observations := []*models.Observation{obs1, obs2, obs3}
|
||||
|
||||
// Cluster with 0.4 threshold
|
||||
clustered := clusterObservations(observations, 0.4)
|
||||
|
||||
// obs1 and obs2 should be clustered together, obs3 is different
|
||||
assert.LessOrEqual(t, len(clustered), 3)
|
||||
assert.GreaterOrEqual(t, len(clustered), 1)
|
||||
|
||||
// The first observation in each cluster should be kept (obs1, obs3)
|
||||
t.Logf("Clustered %d observations down to %d", len(observations), len(clustered))
|
||||
}
|
||||
|
||||
func TestRetrievalStats(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
project := "test-project"
|
||||
createTestObservation(t, svc.observationStore, project,
|
||||
"Test observation",
|
||||
"Test narrative",
|
||||
[]string{"test"})
|
||||
|
||||
// Make a search request
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/context/search?project="+project+"&query=test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// Check stats
|
||||
stats := svc.GetRetrievalStats()
|
||||
assert.Equal(t, int64(1), stats.TotalRequests)
|
||||
assert.Equal(t, int64(1), stats.SearchRequests)
|
||||
assert.GreaterOrEqual(t, stats.ObservationsServed, int64(1))
|
||||
}
|
||||
|
||||
func TestHandleHealth_ReturnsVersion(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
svc.version = "test-version-1.2.3"
|
||||
svc.ready.Store(true)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/health", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.handleHealth(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "ready", response["status"])
|
||||
assert.Equal(t, "test-version-1.2.3", response["version"])
|
||||
}
|
||||
|
||||
func TestHandleVersion(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
svc.version = "v2.0.0-beta"
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/version", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.handleVersion(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response map[string]string
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "v2.0.0-beta", response["version"])
|
||||
}
|
||||
|
||||
func TestHandleReady_ServiceNotReady(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Reset ready state to simulate service not being ready
|
||||
svc.ready.Store(false)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/ready", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.handleReady(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
|
||||
}
|
||||
|
||||
func TestHandleReady_ServiceReady(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
svc.ready.Store(true)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/ready", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
svc.handleReady(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response map[string]string
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "ready", response["status"])
|
||||
}
|
||||
|
||||
func TestRequireReadyMiddleware_Blocks(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
// Reset ready state to simulate service not being ready
|
||||
svc.ready.Store(false)
|
||||
|
||||
handler := svc.requireReady(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("success"))
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
|
||||
}
|
||||
|
||||
func TestRequireReadyMiddleware_Allows(t *testing.T) {
|
||||
svc, cleanup := testService(t)
|
||||
defer cleanup()
|
||||
|
||||
svc.ready.Store(true)
|
||||
|
||||
handler := svc.requireReady(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("success"))
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "success", rec.Body.String())
|
||||
}
|
||||
@@ -0,0 +1,179 @@
|
||||
// Package sdk provides SDK agent integration for claude-mnemonic.
|
||||
package sdk
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
var (
|
||||
// Observation parsing
|
||||
observationRegex = regexp.MustCompile(`(?s)<observation>(.*?)</observation>`)
|
||||
|
||||
// Summary parsing
|
||||
summaryRegex = regexp.MustCompile(`(?s)<summary>(.*?)</summary>`)
|
||||
skipSummaryRegex = regexp.MustCompile(`<skip_summary\s+reason="([^"]+)"\s*/>`)
|
||||
|
||||
// Valid observation types
|
||||
validObsTypes = map[string]bool{
|
||||
"bugfix": true,
|
||||
"feature": true,
|
||||
"refactor": true,
|
||||
"change": true,
|
||||
"discovery": true,
|
||||
"decision": true,
|
||||
}
|
||||
|
||||
// Valid concepts (strict list - no custom tags allowed)
|
||||
validConcepts = map[string]bool{
|
||||
"how-it-works": true,
|
||||
"why-it-exists": true,
|
||||
"what-changed": true,
|
||||
"problem-solution": true,
|
||||
"gotcha": true,
|
||||
"pattern": true,
|
||||
"trade-off": true,
|
||||
}
|
||||
)
|
||||
|
||||
// ParseObservations parses observation XML blocks from SDK response text.
|
||||
func ParseObservations(text string, correlationID string) []*models.ParsedObservation {
|
||||
var observations []*models.ParsedObservation
|
||||
|
||||
matches := observationRegex.FindAllStringSubmatch(text, -1)
|
||||
for _, match := range matches {
|
||||
if len(match) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
obsContent := match[1]
|
||||
|
||||
// Extract fields
|
||||
obsType := extractField(obsContent, "type")
|
||||
title := extractField(obsContent, "title")
|
||||
subtitle := extractField(obsContent, "subtitle")
|
||||
narrative := extractField(obsContent, "narrative")
|
||||
facts := extractArrayElements(obsContent, "facts", "fact")
|
||||
concepts := extractArrayElements(obsContent, "concepts", "concept")
|
||||
filesRead := extractArrayElements(obsContent, "files_read", "file")
|
||||
filesModified := extractArrayElements(obsContent, "files_modified", "file")
|
||||
|
||||
// Determine final type (default to "change" if invalid)
|
||||
finalType := models.ObsTypeChange
|
||||
if obsType != "" {
|
||||
if validObsTypes[obsType] {
|
||||
finalType = models.ObservationType(obsType)
|
||||
} else {
|
||||
log.Warn().
|
||||
Str("correlationId", correlationID).
|
||||
Str("invalidType", obsType).
|
||||
Msg("Invalid observation type, using 'change'")
|
||||
}
|
||||
} else {
|
||||
log.Warn().
|
||||
Str("correlationId", correlationID).
|
||||
Msg("Observation missing type field, using 'change'")
|
||||
}
|
||||
|
||||
// Filter concepts: only keep valid ones from the strict list
|
||||
cleanedConcepts := make([]string, 0, len(concepts))
|
||||
var invalidConcepts []string
|
||||
for _, c := range concepts {
|
||||
c = strings.ToLower(strings.TrimSpace(c))
|
||||
if c == string(finalType) {
|
||||
continue // Skip type in concepts
|
||||
}
|
||||
if validConcepts[c] {
|
||||
cleanedConcepts = append(cleanedConcepts, c)
|
||||
} else {
|
||||
invalidConcepts = append(invalidConcepts, c)
|
||||
}
|
||||
}
|
||||
if len(invalidConcepts) > 0 {
|
||||
log.Warn().
|
||||
Str("correlationId", correlationID).
|
||||
Strs("invalidConcepts", invalidConcepts).
|
||||
Msg("Filtered out invalid concepts (not in allowed list)")
|
||||
}
|
||||
|
||||
observations = append(observations, &models.ParsedObservation{
|
||||
Type: finalType,
|
||||
Title: title,
|
||||
Subtitle: subtitle,
|
||||
Facts: facts,
|
||||
Narrative: narrative,
|
||||
Concepts: cleanedConcepts,
|
||||
FilesRead: filesRead,
|
||||
FilesModified: filesModified,
|
||||
})
|
||||
}
|
||||
|
||||
return observations
|
||||
}
|
||||
|
||||
// ParseSummary parses a summary XML block from SDK response text.
|
||||
func ParseSummary(text string, sessionID int64) *models.ParsedSummary {
|
||||
// Check for skip_summary first
|
||||
if skipMatch := skipSummaryRegex.FindStringSubmatch(text); skipMatch != nil {
|
||||
log.Info().
|
||||
Int64("sessionId", sessionID).
|
||||
Str("reason", skipMatch[1]).
|
||||
Msg("Summary skipped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Find summary block
|
||||
match := summaryRegex.FindStringSubmatch(text)
|
||||
if len(match) < 2 {
|
||||
return nil
|
||||
}
|
||||
|
||||
summaryContent := match[1]
|
||||
|
||||
return &models.ParsedSummary{
|
||||
Request: extractField(summaryContent, "request"),
|
||||
Investigated: extractField(summaryContent, "investigated"),
|
||||
Learned: extractField(summaryContent, "learned"),
|
||||
Completed: extractField(summaryContent, "completed"),
|
||||
NextSteps: extractField(summaryContent, "next_steps"),
|
||||
Notes: extractField(summaryContent, "notes"),
|
||||
}
|
||||
}
|
||||
|
||||
// extractField extracts a simple field value from XML content.
|
||||
func extractField(content, fieldName string) string {
|
||||
pattern := regexp.MustCompile(`<` + fieldName + `>([^<]*)</` + fieldName + `>`)
|
||||
match := pattern.FindStringSubmatch(content)
|
||||
if len(match) < 2 {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(match[1])
|
||||
}
|
||||
|
||||
// extractArrayElements extracts array elements from XML content.
|
||||
func extractArrayElements(content, arrayName, elementName string) []string {
|
||||
var elements []string
|
||||
|
||||
// Find the array block
|
||||
arrayPattern := regexp.MustCompile(`(?s)<` + arrayName + `>(.*?)</` + arrayName + `>`)
|
||||
arrayMatch := arrayPattern.FindStringSubmatch(content)
|
||||
if len(arrayMatch) < 2 {
|
||||
return elements
|
||||
}
|
||||
|
||||
arrayContent := arrayMatch[1]
|
||||
|
||||
// Extract individual elements
|
||||
elementPattern := regexp.MustCompile(`<` + elementName + `>([^<]+)</` + elementName + `>`)
|
||||
elementMatches := elementPattern.FindAllStringSubmatch(arrayContent, -1)
|
||||
for _, match := range elementMatches {
|
||||
if len(match) >= 2 {
|
||||
elements = append(elements, strings.TrimSpace(match[1]))
|
||||
}
|
||||
}
|
||||
|
||||
return elements
|
||||
}
|
||||
@@ -0,0 +1,678 @@
|
||||
// Package sdk provides SDK agent integration for claude-mnemonic.
|
||||
package sdk
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/config"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/similarity"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// BroadcastFunc is a callback for broadcasting events to SSE clients.
|
||||
type BroadcastFunc func(event map[string]interface{})
|
||||
|
||||
// SyncObservationFunc is a callback for syncing observations to vector DB.
|
||||
type SyncObservationFunc func(obs *models.Observation)
|
||||
|
||||
// SyncSummaryFunc is a callback for syncing summaries to vector DB.
|
||||
type SyncSummaryFunc func(summary *models.SessionSummary)
|
||||
|
||||
// Processor handles SDK agent processing of observations and summaries using Claude Code CLI.
|
||||
type Processor struct {
|
||||
claudePath string
|
||||
model string
|
||||
observationStore *sqlite.ObservationStore
|
||||
summaryStore *sqlite.SummaryStore
|
||||
broadcastFunc BroadcastFunc
|
||||
syncObservationFunc SyncObservationFunc
|
||||
syncSummaryFunc SyncSummaryFunc
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// SetBroadcastFunc sets the broadcast callback for SSE events.
|
||||
func (p *Processor) SetBroadcastFunc(fn BroadcastFunc) {
|
||||
p.broadcastFunc = fn
|
||||
}
|
||||
|
||||
// SetSyncObservationFunc sets the callback for syncing observations to vector DB.
|
||||
func (p *Processor) SetSyncObservationFunc(fn SyncObservationFunc) {
|
||||
p.syncObservationFunc = fn
|
||||
}
|
||||
|
||||
// SetSyncSummaryFunc sets the callback for syncing summaries to vector DB.
|
||||
func (p *Processor) SetSyncSummaryFunc(fn SyncSummaryFunc) {
|
||||
p.syncSummaryFunc = fn
|
||||
}
|
||||
|
||||
// broadcast sends an event via the broadcast callback if set.
|
||||
func (p *Processor) broadcast(event map[string]interface{}) {
|
||||
if p.broadcastFunc != nil {
|
||||
p.broadcastFunc(event)
|
||||
}
|
||||
}
|
||||
|
||||
// NewProcessor creates a new SDK processor.
|
||||
func NewProcessor(observationStore *sqlite.ObservationStore, summaryStore *sqlite.SummaryStore) (*Processor, error) {
|
||||
cfg := config.Get()
|
||||
|
||||
// Find Claude Code CLI
|
||||
claudePath := cfg.ClaudeCodePath
|
||||
if claudePath == "" {
|
||||
// Try to find in PATH
|
||||
path, err := exec.LookPath("claude")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("claude CLI not found in PATH and CLAUDE_CODE_PATH not set")
|
||||
}
|
||||
claudePath = path
|
||||
}
|
||||
|
||||
// Verify it exists
|
||||
if _, err := os.Stat(claudePath); err != nil {
|
||||
return nil, fmt.Errorf("claude CLI not found at %s: %w", claudePath, err)
|
||||
}
|
||||
|
||||
return &Processor{
|
||||
claudePath: claudePath,
|
||||
model: cfg.Model,
|
||||
observationStore: observationStore,
|
||||
summaryStore: summaryStore,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ProcessObservation processes a single tool observation and extracts insights.
|
||||
func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, project string, toolName string, toolInput, toolResponse interface{}, promptNumber int, cwd string) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// Skip certain tools that aren't worth processing
|
||||
if shouldSkipTool(toolName) {
|
||||
log.Info().Str("tool", toolName).Msg("Skipping tool (not interesting for memory)")
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Info().Str("tool", toolName).Msg("Processing tool execution with Claude CLI")
|
||||
|
||||
// Convert tool data to strings
|
||||
inputStr := toJSONString(toolInput)
|
||||
outputStr := toJSONString(toolResponse)
|
||||
|
||||
// Check if we already have observations for this file (skip if covered)
|
||||
if filePath := extractFilePath(toolName, inputStr); filePath != "" {
|
||||
exists, err := p.observationStore.ExistsSimilarObservation(ctx, project, []string{filePath}, nil)
|
||||
if err == nil && exists {
|
||||
log.Debug().
|
||||
Str("tool", toolName).
|
||||
Str("file", filePath).
|
||||
Msg("Skipping - file already has observations")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Build the prompt
|
||||
exec := ToolExecution{
|
||||
ToolName: toolName,
|
||||
ToolInput: inputStr,
|
||||
ToolOutput: outputStr,
|
||||
CWD: cwd,
|
||||
}
|
||||
prompt := BuildObservationPrompt(exec)
|
||||
|
||||
// Call Claude Code CLI
|
||||
response, err := p.callClaudeCLI(ctx, prompt)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("tool", toolName).Msg("Failed to call Claude CLI for observation")
|
||||
return err
|
||||
}
|
||||
|
||||
// Parse observations from response
|
||||
observations := ParseObservations(response, sdkSessionID)
|
||||
if len(observations) == 0 {
|
||||
log.Info().Str("tool", toolName).Msg("No observations extracted (Claude deemed not significant)")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get existing observations for deduplication
|
||||
existingObs, err := p.observationStore.GetRecentObservations(ctx, project, 50)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to get existing observations for dedup check")
|
||||
existingObs = nil // Continue without dedup
|
||||
}
|
||||
|
||||
// Store each observation (with deduplication check)
|
||||
const similarityThreshold = 0.4 // Same threshold as retrieval clustering
|
||||
var storedCount, skippedCount int
|
||||
|
||||
for _, obs := range observations {
|
||||
// Capture file modification times for staleness detection
|
||||
obs.FileMtimes = captureFileMtimes(obs.FilesRead, obs.FilesModified, cwd)
|
||||
|
||||
// Convert to stored observation for similarity check
|
||||
storedObs := obs.ToStoredObservation()
|
||||
|
||||
// Check if this observation is too similar to existing ones
|
||||
if existingObs != nil && similarity.IsSimilarToAny(storedObs, existingObs, similarityThreshold) {
|
||||
log.Debug().
|
||||
Str("type", string(obs.Type)).
|
||||
Str("title", obs.Title).
|
||||
Msg("Skipping observation - too similar to existing")
|
||||
skippedCount++
|
||||
continue
|
||||
}
|
||||
|
||||
id, createdAtEpoch, err := p.observationStore.StoreObservation(ctx, sdkSessionID, project, obs, promptNumber, 0)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to store observation")
|
||||
continue
|
||||
}
|
||||
|
||||
storedCount++
|
||||
log.Info().
|
||||
Int64("id", id).
|
||||
Str("type", string(obs.Type)).
|
||||
Str("title", obs.Title).
|
||||
Int("trackedFiles", len(obs.FileMtimes)).
|
||||
Msg("Observation stored")
|
||||
|
||||
// Sync to vector DB if callback is set
|
||||
if p.syncObservationFunc != nil {
|
||||
fullObs := models.NewObservation(sdkSessionID, project, obs, promptNumber, 0)
|
||||
fullObs.ID = id
|
||||
fullObs.CreatedAtEpoch = createdAtEpoch
|
||||
p.syncObservationFunc(fullObs)
|
||||
}
|
||||
|
||||
// Broadcast new observation event for dashboard refresh
|
||||
p.broadcast(map[string]interface{}{
|
||||
"type": "observation",
|
||||
"action": "created",
|
||||
"id": id,
|
||||
"project": project,
|
||||
})
|
||||
|
||||
// Add to existing for subsequent dedup checks within same batch
|
||||
if existingObs != nil {
|
||||
existingObs = append(existingObs, storedObs)
|
||||
}
|
||||
}
|
||||
|
||||
if skippedCount > 0 {
|
||||
log.Info().
|
||||
Int("stored", storedCount).
|
||||
Int("skipped", skippedCount).
|
||||
Msg("Observation processing complete (duplicates skipped)")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProcessSummary processes a session summary request.
|
||||
func (p *Processor) ProcessSummary(ctx context.Context, sessionDBID int64, sdkSessionID, project, userPrompt, lastUserMsg, lastAssistantMsg string) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// Skip summary generation if there's no meaningful assistant response
|
||||
// This prevents generic "initial session setup" summaries
|
||||
if !hasMeaningfulContent(lastAssistantMsg) {
|
||||
log.Info().
|
||||
Int64("sessionId", sessionDBID).
|
||||
Msg("Skipping summary - no meaningful assistant response")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build the summary prompt
|
||||
req := SummaryRequest{
|
||||
SessionDBID: sessionDBID,
|
||||
SDKSessionID: sdkSessionID,
|
||||
Project: project,
|
||||
UserPrompt: userPrompt,
|
||||
LastUserMessage: lastUserMsg,
|
||||
LastAssistantMessage: lastAssistantMsg,
|
||||
}
|
||||
prompt := BuildSummaryPrompt(req)
|
||||
|
||||
// Call Claude Code CLI
|
||||
response, err := p.callClaudeCLI(ctx, prompt)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Int64("sessionId", sessionDBID).Msg("Failed to call Claude CLI for summary")
|
||||
return err
|
||||
}
|
||||
|
||||
// Parse summary from response
|
||||
summary := ParseSummary(response, sessionDBID)
|
||||
if summary == nil {
|
||||
log.Info().Int64("sessionId", sessionDBID).Msg("No summary generated (skipped or empty)")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Filter out summaries that describe the memory agent itself
|
||||
if isSelfReferentialSummary(summary) {
|
||||
log.Info().Int64("sessionId", sessionDBID).Msg("Skipping self-referential summary (describes agent, not user work)")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Store the summary (promptNumber=0, discoveryTokens=0 for summaries)
|
||||
id, createdAtEpoch, err := p.summaryStore.StoreSummary(ctx, sdkSessionID, project, summary, 0, 0)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to store summary")
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Int64("id", id).
|
||||
Int64("sessionId", sessionDBID).
|
||||
Msg("Summary stored")
|
||||
|
||||
// Sync to vector DB if callback is set
|
||||
if p.syncSummaryFunc != nil {
|
||||
fullSummary := models.NewSessionSummary(sdkSessionID, project, summary, 0, 0)
|
||||
fullSummary.ID = id
|
||||
fullSummary.CreatedAtEpoch = createdAtEpoch
|
||||
p.syncSummaryFunc(fullSummary)
|
||||
}
|
||||
|
||||
// Broadcast new summary event for dashboard refresh
|
||||
p.broadcast(map[string]interface{}{
|
||||
"type": "summary",
|
||||
"action": "created",
|
||||
"id": id,
|
||||
"project": project,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// callClaudeCLI calls the Claude Code CLI with the given prompt.
|
||||
func (p *Processor) callClaudeCLI(ctx context.Context, prompt string) (string, error) {
|
||||
// Build the full prompt with system instructions
|
||||
fullPrompt := systemPrompt + "\n\n" + prompt
|
||||
|
||||
// Create command with timeout
|
||||
ctx, cancel := context.WithTimeout(ctx, 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Use claude CLI with --print flag for non-interactive output
|
||||
// and -p for prompt input
|
||||
cmd := exec.CommandContext(ctx, p.claudePath, "--print", "-p", fullPrompt) // #nosec G204 -- claudePath is from config, fullPrompt is internal
|
||||
|
||||
// Set model if specified (use haiku for cost efficiency)
|
||||
if p.model != "" {
|
||||
cmd.Args = append([]string{cmd.Args[0], "--model", p.model}, cmd.Args[1:]...)
|
||||
} else {
|
||||
// Default to haiku for processing (cheap and fast)
|
||||
cmd.Args = append([]string{cmd.Args[0], "--model", "haiku"}, cmd.Args[1:]...)
|
||||
}
|
||||
|
||||
// Run from /tmp to avoid triggering our own hooks
|
||||
// (hooks are triggered based on working directory)
|
||||
cmd.Dir = "/tmp"
|
||||
|
||||
// Disable any plugin hooks by setting an env var that our hooks can check
|
||||
cmd.Env = append(os.Environ(), "CLAUDE_MNEMONIC_INTERNAL=1")
|
||||
|
||||
// Capture output
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
// Run command
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("stderr", stderr.String()).
|
||||
Msg("Claude CLI execution failed")
|
||||
return "", fmt.Errorf("claude CLI failed: %w (stderr: %s)", err, stderr.String())
|
||||
}
|
||||
|
||||
return stdout.String(), nil
|
||||
}
|
||||
|
||||
// shouldSkipTool returns true for tools that aren't worth processing.
|
||||
func shouldSkipTool(toolName string) bool {
|
||||
// Only skip truly uninteresting tools
|
||||
skipTools := map[string]bool{
|
||||
"TodoWrite": true, // Skip TodoWrite - internal tracking
|
||||
"Task": true, // Skip Task - sub-agent spawning
|
||||
"TaskOutput": true, // Skip TaskOutput - sub-agent results
|
||||
"Glob": true, // Skip Glob - just file listing
|
||||
}
|
||||
|
||||
skip, found := skipTools[toolName]
|
||||
if found {
|
||||
return skip
|
||||
}
|
||||
return false // Process all other tools
|
||||
}
|
||||
|
||||
// extractFilePath extracts the file path from tool input for deduplication.
|
||||
func extractFilePath(toolName, inputStr string) string {
|
||||
if inputStr == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
var input map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(inputStr), &input); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Handle different tool input formats
|
||||
switch toolName {
|
||||
case "Read":
|
||||
if fp, ok := input["file_path"].(string); ok {
|
||||
return fp
|
||||
}
|
||||
case "Grep", "Search":
|
||||
if path, ok := input["path"].(string); ok {
|
||||
return path
|
||||
}
|
||||
case "Edit", "Write":
|
||||
if fp, ok := input["file_path"].(string); ok {
|
||||
return fp
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// toJSONString converts an interface to a JSON string.
|
||||
func toJSONString(v interface{}) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// captureFileMtimes captures current modification times for tracked files.
|
||||
// Returns a map of absolute file paths to their mtime in epoch milliseconds.
|
||||
func captureFileMtimes(filesRead, filesModified []string, cwd string) map[string]int64 {
|
||||
mtimes := make(map[string]int64)
|
||||
|
||||
// Helper to get mtime for a file path
|
||||
getMtime := func(path string) (int64, bool) {
|
||||
// Resolve relative paths against cwd
|
||||
absPath := path
|
||||
if !filepath.IsAbs(path) && cwd != "" {
|
||||
absPath = filepath.Join(cwd, path)
|
||||
}
|
||||
|
||||
info, err := os.Stat(absPath)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
return info.ModTime().UnixMilli(), true
|
||||
}
|
||||
|
||||
// Capture mtimes for all read files
|
||||
for _, path := range filesRead {
|
||||
if mtime, ok := getMtime(path); ok {
|
||||
mtimes[path] = mtime
|
||||
}
|
||||
}
|
||||
|
||||
// Capture mtimes for all modified files
|
||||
for _, path := range filesModified {
|
||||
if mtime, ok := getMtime(path); ok {
|
||||
mtimes[path] = mtime
|
||||
}
|
||||
}
|
||||
|
||||
return mtimes
|
||||
}
|
||||
|
||||
// GetFileMtimes returns current modification times for a list of file paths.
|
||||
// This is used for staleness checking when injecting context.
|
||||
func GetFileMtimes(paths []string, cwd string) map[string]int64 {
|
||||
return captureFileMtimes(paths, nil, cwd)
|
||||
}
|
||||
|
||||
// GetFileContent reads file content for verification purposes.
|
||||
// Returns content and ok status.
|
||||
func GetFileContent(path, cwd string) (string, bool) {
|
||||
absPath := path
|
||||
if !filepath.IsAbs(path) && cwd != "" {
|
||||
absPath = filepath.Join(cwd, path)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(absPath) // #nosec G304 -- intentional file read for verification
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// Limit to first 2000 chars for verification (enough context, not too expensive)
|
||||
if len(content) > 2000 {
|
||||
return string(content[:2000]) + "\n...[truncated]", true
|
||||
}
|
||||
return string(content), true
|
||||
}
|
||||
|
||||
// VerifyObservation checks if an observation is still valid given the current file contents.
|
||||
// Returns true if the observation is still accurate, false if it should be deleted.
|
||||
func (p *Processor) VerifyObservation(ctx context.Context, obs *models.Observation, cwd string) bool {
|
||||
// Build file content context
|
||||
var fileContents []string
|
||||
var paths []string
|
||||
|
||||
// Combine files_read and files_modified
|
||||
for _, path := range obs.FilesRead {
|
||||
paths = append(paths, path)
|
||||
}
|
||||
for _, path := range obs.FilesModified {
|
||||
paths = append(paths, path)
|
||||
}
|
||||
|
||||
// Get current content of tracked files
|
||||
for _, path := range paths {
|
||||
if content, ok := GetFileContent(path, cwd); ok {
|
||||
fileContents = append(fileContents, fmt.Sprintf("=== %s ===\n%s", path, content))
|
||||
}
|
||||
}
|
||||
|
||||
if len(fileContents) == 0 {
|
||||
// No files available to verify against - keep the observation
|
||||
return true
|
||||
}
|
||||
|
||||
// Build verification prompt
|
||||
prompt := fmt.Sprintf(`You are verifying if a previously recorded observation is still accurate.
|
||||
|
||||
OBSERVATION:
|
||||
- Type: %s
|
||||
- Title: %s
|
||||
- Subtitle: %s
|
||||
- Narrative: %s
|
||||
- Facts: %v
|
||||
|
||||
CURRENT FILE CONTENTS:
|
||||
%s
|
||||
|
||||
TASK: Check if the observation is still accurate given the current file contents.
|
||||
Reply with ONLY one of:
|
||||
- VALID - if the observation is still accurate
|
||||
- INVALID - if the observation is no longer accurate (the code/behavior changed)
|
||||
- UNCERTAIN - if you can't determine validity (files might be incomplete)
|
||||
|
||||
Your response:`,
|
||||
obs.Type,
|
||||
obs.Title.String,
|
||||
obs.Subtitle.String,
|
||||
obs.Narrative.String,
|
||||
obs.Facts,
|
||||
strings.Join(fileContents, "\n\n"),
|
||||
)
|
||||
|
||||
// Call Claude CLI for quick verification
|
||||
response, err := p.callClaudeCLI(ctx, prompt)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to verify observation, keeping it")
|
||||
return true // On error, keep the observation
|
||||
}
|
||||
|
||||
response = strings.TrimSpace(strings.ToUpper(response))
|
||||
|
||||
// Parse response
|
||||
if strings.Contains(response, "INVALID") {
|
||||
log.Info().
|
||||
Int64("id", obs.ID).
|
||||
Str("title", obs.Title.String).
|
||||
Msg("Observation verified as INVALID - will delete")
|
||||
return false
|
||||
}
|
||||
|
||||
// VALID or UNCERTAIN - keep the observation
|
||||
log.Debug().
|
||||
Int64("id", obs.ID).
|
||||
Str("title", obs.Title.String).
|
||||
Str("result", response).
|
||||
Msg("Observation verified")
|
||||
return true
|
||||
}
|
||||
|
||||
// isSelfReferentialSummary checks if a summary describes the memory agent itself
|
||||
// rather than actual user work. These summaries should be filtered out.
|
||||
func isSelfReferentialSummary(summary *models.ParsedSummary) bool {
|
||||
// Combine all summary fields for checking
|
||||
content := strings.ToLower(summary.Request + " " + summary.Completed + " " + summary.Learned + " " + summary.NextSteps)
|
||||
|
||||
// Indicators that the summary is about the memory agent, not user work
|
||||
selfReferentialPhrases := []string{
|
||||
"memory extraction",
|
||||
"memory agent",
|
||||
"hook execution",
|
||||
"hook mechanism",
|
||||
"session initialization",
|
||||
"session setup",
|
||||
"agent initialization",
|
||||
"no technical learnings",
|
||||
"no code or project work",
|
||||
"waiting for the user",
|
||||
"waiting for user",
|
||||
"awaiting actual",
|
||||
"awaiting claude code",
|
||||
"progress checkpoint",
|
||||
"checkpoint request",
|
||||
}
|
||||
|
||||
matchCount := 0
|
||||
for _, phrase := range selfReferentialPhrases {
|
||||
if strings.Contains(content, phrase) {
|
||||
matchCount++
|
||||
}
|
||||
}
|
||||
|
||||
// If the summary mentions 2+ self-referential phrases, it's about the agent
|
||||
return matchCount >= 2
|
||||
}
|
||||
|
||||
// hasMeaningfulContent checks if the assistant response contains meaningful content
|
||||
// worth generating a summary for. This filters out initial greetings, empty sessions,
|
||||
// and sessions where only system messages were exchanged.
|
||||
func hasMeaningfulContent(assistantMsg string) bool {
|
||||
// Skip if empty or too short (need substantial content)
|
||||
if len(strings.TrimSpace(assistantMsg)) < 200 {
|
||||
return false
|
||||
}
|
||||
|
||||
lowerMsg := strings.ToLower(assistantMsg)
|
||||
|
||||
// Skip messages that are primarily about system/hook status
|
||||
skipIndicators := []string{
|
||||
"hook success",
|
||||
"callback hook",
|
||||
"session start",
|
||||
"sessionstart",
|
||||
"system-reminder",
|
||||
"memory extraction agent",
|
||||
"memory agent",
|
||||
"no technical learnings",
|
||||
"waiting for",
|
||||
"waiting to",
|
||||
"no code or project work",
|
||||
"no substantive",
|
||||
}
|
||||
|
||||
skipCount := 0
|
||||
for _, skip := range skipIndicators {
|
||||
if strings.Contains(lowerMsg, skip) {
|
||||
skipCount++
|
||||
}
|
||||
}
|
||||
// If multiple skip indicators found, this is likely a system-only session
|
||||
if skipCount >= 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for indicators of actual work being done
|
||||
workIndicators := []string{
|
||||
// Concrete file operations (with paths)
|
||||
".go", ".ts", ".js", ".py", ".md", ".json", ".yaml", ".yml",
|
||||
// Code modifications
|
||||
"edited", "modified", "created", "deleted", "updated", "changed",
|
||||
"added", "removed", "fixed", "implemented", "refactored",
|
||||
// Tool results
|
||||
"```", "lines ", "function ", "const ", "var ", "let ",
|
||||
"type ", "struct ", "class ", "def ", "func ",
|
||||
}
|
||||
|
||||
matchCount := 0
|
||||
for _, indicator := range workIndicators {
|
||||
if strings.Contains(lowerMsg, strings.ToLower(indicator)) {
|
||||
matchCount++
|
||||
}
|
||||
}
|
||||
|
||||
// Require at least 2 work indicators to generate a summary
|
||||
return matchCount >= 2
|
||||
}
|
||||
|
||||
const systemPrompt = `You are a memory extraction agent for Claude Code sessions. Your job is to analyze tool executions and extract meaningful observations that would be useful for future sessions.
|
||||
|
||||
GUIDELINES:
|
||||
1. Only create observations for SIGNIFICANT learnings - not every tool call needs one
|
||||
2. Focus on: decisions made, bugs fixed, patterns discovered, project structure learned
|
||||
3. Skip trivial operations like simple file reads without insights
|
||||
4. Be concise but informative in your observations
|
||||
5. Use appropriate type tags: decision, bugfix, feature, refactor, discovery, change
|
||||
|
||||
OUTPUT FORMAT:
|
||||
When you find something worth remembering, output:
|
||||
<observation>
|
||||
<type>decision|bugfix|feature|refactor|discovery|change</type>
|
||||
<title>Short descriptive title</title>
|
||||
<subtitle>One-line summary</subtitle>
|
||||
<narrative>Detailed explanation</narrative>
|
||||
<facts>
|
||||
<fact>Specific fact 1</fact>
|
||||
</facts>
|
||||
<concepts>
|
||||
<concept>tag1</concept>
|
||||
</concepts>
|
||||
<files_read>
|
||||
<file>/path/to/file</file>
|
||||
</files_read>
|
||||
<files_modified>
|
||||
<file>/path/to/file</file>
|
||||
</files_modified>
|
||||
</observation>
|
||||
|
||||
If the tool execution is not noteworthy, simply respond with:
|
||||
<skip reason="not significant"/>`
|
||||
@@ -0,0 +1,117 @@
|
||||
// Package sdk provides SDK agent integration for claude-mnemonic.
|
||||
package sdk
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ObservationTypes defines valid observation types.
|
||||
var ObservationTypes = []string{"bugfix", "feature", "refactor", "change", "discovery", "decision"}
|
||||
|
||||
// ObservationConcepts defines valid observation concepts.
|
||||
var ObservationConcepts = []string{
|
||||
"how-it-works",
|
||||
"why-it-exists",
|
||||
"what-changed",
|
||||
"problem-solution",
|
||||
"gotcha",
|
||||
"pattern",
|
||||
"trade-off",
|
||||
}
|
||||
|
||||
// ToolExecution represents a tool execution for observation.
|
||||
type ToolExecution struct {
|
||||
ID int64
|
||||
ToolName string
|
||||
ToolInput string
|
||||
ToolOutput string
|
||||
CreatedAtEpoch int64
|
||||
CWD string
|
||||
}
|
||||
|
||||
// BuildObservationPrompt builds a prompt for processing a tool observation.
|
||||
func BuildObservationPrompt(exec ToolExecution) string {
|
||||
// Safely parse tool_input and tool_output
|
||||
var toolInput interface{}
|
||||
var toolOutput interface{}
|
||||
|
||||
if err := json.Unmarshal([]byte(exec.ToolInput), &toolInput); err != nil {
|
||||
toolInput = exec.ToolInput
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(exec.ToolOutput), &toolOutput); err != nil {
|
||||
toolOutput = exec.ToolOutput
|
||||
}
|
||||
|
||||
inputJSON, _ := json.MarshalIndent(toolInput, " ", " ")
|
||||
outputJSON, _ := json.MarshalIndent(toolOutput, " ", " ")
|
||||
|
||||
timestamp := time.UnixMilli(exec.CreatedAtEpoch).Format(time.RFC3339)
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString("<observed_from_primary_session>\n")
|
||||
sb.WriteString(fmt.Sprintf(" <what_happened>%s</what_happened>\n", exec.ToolName))
|
||||
sb.WriteString(fmt.Sprintf(" <occurred_at>%s</occurred_at>\n", timestamp))
|
||||
if exec.CWD != "" {
|
||||
sb.WriteString(fmt.Sprintf(" <working_directory>%s</working_directory>\n", exec.CWD))
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" <parameters>%s</parameters>\n", truncate(string(inputJSON), 3000)))
|
||||
sb.WriteString(fmt.Sprintf(" <outcome>%s</outcome>\n", truncate(string(outputJSON), 5000)))
|
||||
sb.WriteString("</observed_from_primary_session>")
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// SummaryRequest contains data for building a summary prompt.
|
||||
type SummaryRequest struct {
|
||||
SessionDBID int64
|
||||
SDKSessionID string
|
||||
Project string
|
||||
UserPrompt string
|
||||
LastUserMessage string
|
||||
LastAssistantMessage string
|
||||
}
|
||||
|
||||
// BuildSummaryPrompt builds a prompt requesting a session summary.
|
||||
func BuildSummaryPrompt(req SummaryRequest) string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString("PROGRESS SUMMARY CHECKPOINT\n")
|
||||
sb.WriteString("===========================\n")
|
||||
sb.WriteString("Write progress notes of what was done, what was learned, and what's next. This is a checkpoint to capture progress so far. The session is ongoing - you may receive more requests and tool executions after this summary. Write \"next_steps\" as the current trajectory of work (what's actively being worked on or coming up next), not as post-session future work. Always write at least a minimal summary explaining current progress, even if work is still in early stages, so that users see a summary output tied to each request.\n\n")
|
||||
|
||||
if req.LastAssistantMessage != "" {
|
||||
sb.WriteString("Claude's Full Response to User:\n")
|
||||
sb.WriteString(truncate(req.LastAssistantMessage, 4000))
|
||||
sb.WriteString("\n\n")
|
||||
}
|
||||
|
||||
sb.WriteString(`Respond in this XML format:
|
||||
<summary>
|
||||
<request>[Short title capturing the user's request AND the substance of what was discussed/done]</request>
|
||||
<investigated>[What has been explored so far? What was examined?]</investigated>
|
||||
<learned>[What have you learned about how things work?]</learned>
|
||||
<completed>[What work has been completed so far? What has shipped or changed?]</completed>
|
||||
<next_steps>[What are you actively working on or planning to work on next in this session?]</next_steps>
|
||||
<notes>[Additional insights or observations about the current progress]</notes>
|
||||
</summary>
|
||||
|
||||
IMPORTANT! DO NOT do any work right now other than generating this next PROGRESS SUMMARY - and remember that you are a memory agent designed to summarize a DIFFERENT claude code session, not this one.
|
||||
|
||||
Never reference yourself or your own actions. Do not output anything other than the summary content formatted in the XML structure above. All other output is ignored by the system, and the system has been designed to be smart about token usage. Please spend your tokens wisely on useful summary content.
|
||||
|
||||
Thank you, this summary will be very useful for keeping track of our progress!`)
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// truncate truncates a string to the specified length.
|
||||
func truncate(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return s[:maxLen] + "... (truncated)"
|
||||
}
|
||||
@@ -0,0 +1,805 @@
|
||||
// Package worker provides the main worker service for claude-mnemonic.
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/config"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/chroma"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/watcher"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/sdk"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/session"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/sse"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Service configuration constants
|
||||
const (
|
||||
// DefaultHTTPTimeout is the default timeout for HTTP requests.
|
||||
DefaultHTTPTimeout = 30 * time.Second
|
||||
|
||||
// ReadyPollInterval is how often WaitReady checks initialization status.
|
||||
ReadyPollInterval = 50 * time.Millisecond
|
||||
|
||||
// StaleQueueSize is the buffer size for background stale verification.
|
||||
StaleQueueSize = 100
|
||||
|
||||
// QueueProcessInterval is how often the background queue processor runs.
|
||||
QueueProcessInterval = 2 * time.Second
|
||||
)
|
||||
|
||||
// RetrievalStats tracks observation retrieval metrics.
|
||||
type RetrievalStats struct {
|
||||
TotalRequests int64 // Total retrieval requests (inject + search)
|
||||
ObservationsServed int64 // Observations returned to clients
|
||||
VerifiedStale int64 // Stale observations that passed verification
|
||||
DeletedInvalid int64 // Invalid observations deleted
|
||||
SearchRequests int64 // Semantic search requests
|
||||
ContextInjections int64 // Session-start context injections
|
||||
}
|
||||
|
||||
// Service is the main worker service orchestrator.
|
||||
type Service struct {
|
||||
// Version of the worker binary
|
||||
version string
|
||||
|
||||
// Configuration
|
||||
config *config.Config
|
||||
|
||||
// Database
|
||||
store *sqlite.Store
|
||||
sessionStore *sqlite.SessionStore
|
||||
observationStore *sqlite.ObservationStore
|
||||
summaryStore *sqlite.SummaryStore
|
||||
promptStore *sqlite.PromptStore
|
||||
|
||||
// Domain services
|
||||
sessionManager *session.Manager
|
||||
sseBroadcaster *sse.Broadcaster
|
||||
processor *sdk.Processor
|
||||
|
||||
// Vector database
|
||||
chromaClient *chroma.Client
|
||||
chromaSync *chroma.Sync
|
||||
|
||||
// HTTP server
|
||||
router *chi.Mux
|
||||
server *http.Server
|
||||
startTime time.Time
|
||||
|
||||
// Retrieval statistics
|
||||
retrievalStats RetrievalStats
|
||||
|
||||
// Lifecycle
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
|
||||
// Initialization state (for deferred init)
|
||||
ready atomic.Bool
|
||||
initError error
|
||||
initMu sync.RWMutex
|
||||
|
||||
// Background verification queue for stale observations
|
||||
staleQueue chan staleVerifyRequest
|
||||
staleQueueOnce sync.Once
|
||||
|
||||
// File watchers for auto-recreation on deletion
|
||||
dbWatcher *watcher.Watcher
|
||||
configWatcher *watcher.Watcher
|
||||
}
|
||||
|
||||
// staleVerifyRequest represents a request to verify a stale observation in background
|
||||
type staleVerifyRequest struct {
|
||||
observationID int64
|
||||
cwd string
|
||||
}
|
||||
|
||||
// NewService creates a new worker service with deferred initialization.
|
||||
// The service starts immediately with health endpoint available,
|
||||
// while database and SDK initialization happens in the background.
|
||||
func NewService(version string) (*Service, error) {
|
||||
cfg := config.Get()
|
||||
|
||||
// Create context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Create router and SSE broadcaster (lightweight, no dependencies)
|
||||
router := chi.NewRouter()
|
||||
sseBroadcaster := sse.NewBroadcaster()
|
||||
|
||||
svc := &Service{
|
||||
version: version,
|
||||
config: cfg,
|
||||
sseBroadcaster: sseBroadcaster,
|
||||
router: router,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
startTime: time.Now(),
|
||||
}
|
||||
|
||||
// Setup middleware and routes (health endpoint works immediately)
|
||||
svc.setupMiddleware()
|
||||
svc.setupRoutes()
|
||||
|
||||
// Start async initialization
|
||||
go svc.initializeAsync()
|
||||
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
// initializeAsync performs heavy initialization in the background.
|
||||
func (s *Service) initializeAsync() {
|
||||
log.Info().Msg("Starting async initialization...")
|
||||
|
||||
// Ensure data directory, vector-db, and settings exist
|
||||
if err := config.EnsureAll(); err != nil {
|
||||
s.setInitError(fmt.Errorf("ensure data dir: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Initialize database (this includes migrations - can be slow)
|
||||
store, err := sqlite.NewStore(sqlite.StoreConfig{
|
||||
Path: s.config.DBPath,
|
||||
MaxConns: s.config.MaxConns,
|
||||
WALMode: true,
|
||||
})
|
||||
if err != nil {
|
||||
s.setInitError(fmt.Errorf("init database: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Create store wrappers
|
||||
sessionStore := sqlite.NewSessionStore(store)
|
||||
observationStore := sqlite.NewObservationStore(store)
|
||||
summaryStore := sqlite.NewSummaryStore(store)
|
||||
promptStore := sqlite.NewPromptStore(store)
|
||||
|
||||
// Create session manager
|
||||
sessionManager := session.NewManager(sessionStore)
|
||||
|
||||
// Create ChromaDB client for vector search (optional - will be nil if unavailable)
|
||||
var chromaClient *chroma.Client
|
||||
var chromaSync *chroma.Sync
|
||||
chromaCfg := chroma.Config{
|
||||
Project: "default", // Collection prefix
|
||||
DataDir: s.config.VectorDBPath,
|
||||
BatchSize: 100,
|
||||
}
|
||||
client, err := chroma.NewClient(chromaCfg)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("ChromaDB client creation failed - vector sync disabled")
|
||||
} else {
|
||||
// Connect to ChromaDB (starts the MCP server)
|
||||
if err := client.Connect(s.ctx); err != nil {
|
||||
log.Warn().Err(err).Msg("ChromaDB connection failed - vector sync disabled")
|
||||
} else {
|
||||
chromaClient = client
|
||||
chromaSync = chroma.NewSync(client)
|
||||
log.Info().Msg("ChromaDB client connected - vector sync enabled")
|
||||
}
|
||||
}
|
||||
|
||||
// Create SDK processor (optional - will be nil if Claude CLI not available)
|
||||
var processor *sdk.Processor
|
||||
proc, err := sdk.NewProcessor(observationStore, summaryStore)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("SDK processor not available - observations will be queued but not processed")
|
||||
} else {
|
||||
processor = proc
|
||||
// Set broadcast callback for SSE events
|
||||
processor.SetBroadcastFunc(func(event map[string]interface{}) {
|
||||
s.sseBroadcaster.Broadcast(event)
|
||||
})
|
||||
log.Info().Msg("SDK processor initialized")
|
||||
}
|
||||
|
||||
// Set all the initialized components
|
||||
s.initMu.Lock()
|
||||
s.store = store
|
||||
s.sessionStore = sessionStore
|
||||
s.observationStore = observationStore
|
||||
s.summaryStore = summaryStore
|
||||
s.promptStore = promptStore
|
||||
s.sessionManager = sessionManager
|
||||
s.processor = processor
|
||||
s.chromaClient = chromaClient
|
||||
s.chromaSync = chromaSync
|
||||
s.initMu.Unlock()
|
||||
|
||||
// Set vector sync callbacks on processor if both are available
|
||||
if processor != nil && chromaSync != nil {
|
||||
processor.SetSyncObservationFunc(func(obs *models.Observation) {
|
||||
if err := chromaSync.SyncObservation(s.ctx, obs); err != nil {
|
||||
log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation to ChromaDB")
|
||||
}
|
||||
})
|
||||
processor.SetSyncSummaryFunc(func(summary *models.SessionSummary) {
|
||||
if err := chromaSync.SyncSummary(s.ctx, summary); err != nil {
|
||||
log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary to ChromaDB")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Set cleanup callback on observation store to sync deletes to ChromaDB
|
||||
if observationStore != nil && chromaSync != nil {
|
||||
observationStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) {
|
||||
if err := chromaSync.DeleteObservations(ctx, deletedIDs); err != nil {
|
||||
log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete observations from ChromaDB")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Set cleanup callback on prompt store to sync deletes to ChromaDB
|
||||
if promptStore != nil && chromaSync != nil {
|
||||
promptStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) {
|
||||
if err := chromaSync.DeleteUserPrompts(ctx, deletedIDs); err != nil {
|
||||
log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete prompts from ChromaDB")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Set callback for session deletion
|
||||
sessionManager.SetOnSessionDeleted(func(id int64) {
|
||||
s.broadcastProcessingStatus()
|
||||
})
|
||||
|
||||
// Mark as ready
|
||||
s.ready.Store(true)
|
||||
log.Info().Msg("Async initialization complete - service ready")
|
||||
|
||||
// Start queue processor if SDK processor is available
|
||||
if processor != nil {
|
||||
s.wg.Add(1)
|
||||
go s.processQueue()
|
||||
}
|
||||
|
||||
// Start file watchers for auto-recreation on deletion
|
||||
s.startWatchers()
|
||||
}
|
||||
|
||||
// startWatchers initializes and starts file watchers for database and config.
|
||||
func (s *Service) startWatchers() {
|
||||
// Watch database file for deletion
|
||||
dbWatcher, err := watcher.New(s.config.DBPath, func() {
|
||||
log.Warn().Str("path", s.config.DBPath).Msg("Database file deleted, reinitializing...")
|
||||
s.reinitializeDatabase()
|
||||
})
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to create database watcher")
|
||||
} else {
|
||||
s.dbWatcher = dbWatcher
|
||||
if err := dbWatcher.Start(); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to start database watcher")
|
||||
} else {
|
||||
log.Info().Str("path", s.config.DBPath).Msg("Database file watcher started")
|
||||
}
|
||||
}
|
||||
|
||||
// Watch config file for changes (triggers process exit for restart)
|
||||
configPath := config.SettingsPath()
|
||||
configWatcher, err := watcher.New(configPath, func() {
|
||||
log.Warn().Str("path", configPath).Msg("Config file changed, reloading...")
|
||||
s.reloadConfig()
|
||||
})
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to create config watcher")
|
||||
} else {
|
||||
s.configWatcher = configWatcher
|
||||
if err := configWatcher.Start(); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to start config watcher")
|
||||
} else {
|
||||
log.Info().Str("path", configPath).Msg("Config file watcher started")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reinitializeDatabase recreates the database after deletion.
|
||||
func (s *Service) reinitializeDatabase() {
|
||||
// Block new requests
|
||||
s.ready.Store(false)
|
||||
log.Info().Msg("Database reinitialization starting...")
|
||||
|
||||
// Get old store references
|
||||
s.initMu.Lock()
|
||||
oldStore := s.store
|
||||
oldSessionManager := s.sessionManager
|
||||
oldChromaClient := s.chromaClient
|
||||
s.initMu.Unlock()
|
||||
|
||||
// Close old stores
|
||||
if oldChromaClient != nil {
|
||||
if err := oldChromaClient.Close(); err != nil {
|
||||
log.Warn().Err(err).Msg("Error closing old ChromaDB client")
|
||||
}
|
||||
}
|
||||
if oldStore != nil {
|
||||
if err := oldStore.Close(); err != nil {
|
||||
log.Warn().Err(err).Msg("Error closing old database")
|
||||
}
|
||||
}
|
||||
|
||||
// Clear in-memory sessions (they reference old DB IDs)
|
||||
if oldSessionManager != nil {
|
||||
oldSessionManager.ShutdownAll(s.ctx)
|
||||
}
|
||||
|
||||
// Ensure data directory, vector-db, and settings exist (may have been deleted)
|
||||
if err := config.EnsureAll(); err != nil {
|
||||
s.setInitError(fmt.Errorf("ensure data dir on reinit: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Create new database
|
||||
store, err := sqlite.NewStore(sqlite.StoreConfig{
|
||||
Path: s.config.DBPath,
|
||||
MaxConns: s.config.MaxConns,
|
||||
WALMode: true,
|
||||
})
|
||||
if err != nil {
|
||||
s.setInitError(fmt.Errorf("reinit database: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
// Create new store wrappers
|
||||
sessionStore := sqlite.NewSessionStore(store)
|
||||
observationStore := sqlite.NewObservationStore(store)
|
||||
summaryStore := sqlite.NewSummaryStore(store)
|
||||
promptStore := sqlite.NewPromptStore(store)
|
||||
|
||||
// Create new session manager
|
||||
sessionManager := session.NewManager(sessionStore)
|
||||
|
||||
// Recreate ChromaDB client
|
||||
var chromaClient *chroma.Client
|
||||
var chromaSync *chroma.Sync
|
||||
chromaCfg := chroma.Config{
|
||||
Project: "default",
|
||||
DataDir: s.config.VectorDBPath,
|
||||
BatchSize: 100,
|
||||
}
|
||||
client, err := chroma.NewClient(chromaCfg)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("ChromaDB client creation failed after reinit")
|
||||
} else {
|
||||
if err := client.Connect(s.ctx); err != nil {
|
||||
log.Warn().Err(err).Msg("ChromaDB connection failed after reinit")
|
||||
} else {
|
||||
chromaClient = client
|
||||
chromaSync = chroma.NewSync(client)
|
||||
log.Info().Msg("ChromaDB client reconnected after reinit")
|
||||
}
|
||||
}
|
||||
|
||||
// Recreate SDK processor with new stores
|
||||
var processor *sdk.Processor
|
||||
proc, err := sdk.NewProcessor(observationStore, summaryStore)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("SDK processor not available after reinit")
|
||||
} else {
|
||||
processor = proc
|
||||
processor.SetBroadcastFunc(func(event map[string]interface{}) {
|
||||
s.sseBroadcaster.Broadcast(event)
|
||||
})
|
||||
}
|
||||
|
||||
// Atomically swap all components
|
||||
s.initMu.Lock()
|
||||
s.store = store
|
||||
s.sessionStore = sessionStore
|
||||
s.observationStore = observationStore
|
||||
s.summaryStore = summaryStore
|
||||
s.promptStore = promptStore
|
||||
s.sessionManager = sessionManager
|
||||
s.processor = processor
|
||||
s.chromaClient = chromaClient
|
||||
s.chromaSync = chromaSync
|
||||
s.initError = nil
|
||||
s.initMu.Unlock()
|
||||
|
||||
// Set vector sync callbacks on processor if both are available
|
||||
if processor != nil && chromaSync != nil {
|
||||
processor.SetSyncObservationFunc(func(obs *models.Observation) {
|
||||
if err := chromaSync.SyncObservation(s.ctx, obs); err != nil {
|
||||
log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation to ChromaDB")
|
||||
}
|
||||
})
|
||||
processor.SetSyncSummaryFunc(func(summary *models.SessionSummary) {
|
||||
if err := chromaSync.SyncSummary(s.ctx, summary); err != nil {
|
||||
log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary to ChromaDB")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Set cleanup callback on observation store to sync deletes to ChromaDB
|
||||
if observationStore != nil && chromaSync != nil {
|
||||
observationStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) {
|
||||
if err := chromaSync.DeleteObservations(ctx, deletedIDs); err != nil {
|
||||
log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete observations from ChromaDB")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Set cleanup callback on prompt store to sync deletes to ChromaDB
|
||||
if promptStore != nil && chromaSync != nil {
|
||||
promptStore.SetCleanupFunc(func(ctx context.Context, deletedIDs []int64) {
|
||||
if err := chromaSync.DeleteUserPrompts(ctx, deletedIDs); err != nil {
|
||||
log.Warn().Err(err).Ints64("ids", deletedIDs).Msg("Failed to delete prompts from ChromaDB")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Set callback for session deletion
|
||||
sessionManager.SetOnSessionDeleted(func(id int64) {
|
||||
s.broadcastProcessingStatus()
|
||||
})
|
||||
|
||||
// Mark as ready again
|
||||
s.ready.Store(true)
|
||||
log.Info().Msg("Database reinitialization complete")
|
||||
|
||||
// Broadcast status update
|
||||
s.sseBroadcaster.Broadcast(map[string]interface{}{
|
||||
"type": "database_reinitialized",
|
||||
"message": "Database was recreated after deletion",
|
||||
})
|
||||
}
|
||||
|
||||
// reloadConfig reloads configuration from disk.
|
||||
// For now, this triggers a graceful restart by exiting (hooks will restart us).
|
||||
func (s *Service) reloadConfig() {
|
||||
log.Info().Msg("Config changed, triggering graceful restart...")
|
||||
|
||||
// Broadcast notification
|
||||
s.sseBroadcaster.Broadcast(map[string]interface{}{
|
||||
"type": "config_changed",
|
||||
"message": "Configuration changed, restarting worker...",
|
||||
})
|
||||
|
||||
// Give SSE clients a moment to receive the message
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Exit cleanly - hooks will restart us with new config
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// setInitError records an initialization error.
|
||||
func (s *Service) setInitError(err error) {
|
||||
s.initMu.Lock()
|
||||
s.initError = err
|
||||
s.initMu.Unlock()
|
||||
log.Error().Err(err).Msg("Async initialization failed")
|
||||
}
|
||||
|
||||
// GetInitError returns any initialization error.
|
||||
func (s *Service) GetInitError() error {
|
||||
s.initMu.RLock()
|
||||
defer s.initMu.RUnlock()
|
||||
return s.initError
|
||||
}
|
||||
|
||||
// queueStaleVerification queues a stale observation for background verification.
|
||||
// This is non-blocking - if the queue is full, the request is dropped.
|
||||
func (s *Service) queueStaleVerification(observationID int64, cwd string) {
|
||||
// Initialize queue on first use
|
||||
s.staleQueueOnce.Do(func() {
|
||||
s.staleQueue = make(chan staleVerifyRequest, StaleQueueSize)
|
||||
s.wg.Add(1)
|
||||
go s.processStaleQueue()
|
||||
})
|
||||
|
||||
// Non-blocking send - drop if queue is full
|
||||
select {
|
||||
case s.staleQueue <- staleVerifyRequest{observationID: observationID, cwd: cwd}:
|
||||
// Queued
|
||||
default:
|
||||
// Queue full, drop
|
||||
log.Debug().Int64("id", observationID).Msg("Stale verification queue full, dropping")
|
||||
}
|
||||
}
|
||||
|
||||
// processStaleQueue processes stale observations in the background.
|
||||
func (s *Service) processStaleQueue() {
|
||||
defer s.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
case req := <-s.staleQueue:
|
||||
s.verifyStaleObservation(req)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// verifyStaleObservation verifies a single stale observation in the background.
|
||||
func (s *Service) verifyStaleObservation(req staleVerifyRequest) {
|
||||
// Wait for service to be ready
|
||||
if !s.ready.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
// Get observation from DB
|
||||
s.initMu.RLock()
|
||||
store := s.observationStore
|
||||
processor := s.processor
|
||||
s.initMu.RUnlock()
|
||||
|
||||
if store == nil || processor == nil {
|
||||
return
|
||||
}
|
||||
|
||||
obs, err := store.GetObservationByID(s.ctx, req.observationID)
|
||||
if err != nil || obs == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Verify with Claude CLI (this is slow but we're in background)
|
||||
if !processor.VerifyObservation(s.ctx, obs, req.cwd) {
|
||||
// Invalid - delete it
|
||||
deleted, err := store.DeleteObservations(s.ctx, []int64{obs.ID})
|
||||
if err == nil && deleted > 0 {
|
||||
log.Info().
|
||||
Int64("id", obs.ID).
|
||||
Str("title", obs.Title.String).
|
||||
Msg("Background verification: deleted invalid observation")
|
||||
}
|
||||
} else {
|
||||
log.Debug().
|
||||
Int64("id", obs.ID).
|
||||
Msg("Background verification: observation still valid")
|
||||
}
|
||||
}
|
||||
|
||||
// setupMiddleware configures HTTP middleware.
|
||||
func (s *Service) setupMiddleware() {
|
||||
s.router.Use(middleware.Logger)
|
||||
s.router.Use(middleware.Recoverer)
|
||||
s.router.Use(middleware.Timeout(DefaultHTTPTimeout))
|
||||
s.router.Use(middleware.RealIP)
|
||||
}
|
||||
|
||||
// setupRoutes configures HTTP routes.
|
||||
func (s *Service) setupRoutes() {
|
||||
// Serve Vue dashboard from embedded static files
|
||||
s.router.Get("/", serveIndex)
|
||||
s.router.Get("/assets/*", serveAssets)
|
||||
|
||||
// Health check (both root and API-prefixed for compatibility)
|
||||
// Returns 200 immediately so hooks can connect quickly during init
|
||||
// Also returns version for stale worker detection
|
||||
s.router.Get("/health", s.handleHealth)
|
||||
s.router.Get("/api/health", s.handleHealth)
|
||||
|
||||
// Version endpoint for hooks to check if worker needs restart
|
||||
s.router.Get("/api/version", s.handleVersion)
|
||||
|
||||
// Readiness check - returns 200 only when fully initialized
|
||||
s.router.Get("/api/ready", s.handleReady)
|
||||
|
||||
// SSE endpoint (works before DB is ready)
|
||||
s.router.Get("/api/events", s.sseBroadcaster.HandleSSE)
|
||||
|
||||
// Routes that require DB to be ready
|
||||
s.router.Group(func(r chi.Router) {
|
||||
r.Use(s.requireReady)
|
||||
|
||||
// Session routes
|
||||
r.Post("/api/sessions/init", s.handleSessionInit)
|
||||
r.Get("/api/sessions", s.handleGetSessionByClaudeID)
|
||||
r.Post("/sessions/{id}/init", s.handleSessionStart)
|
||||
r.Post("/api/sessions/observations", s.handleObservation)
|
||||
r.Post("/api/sessions/subagent-complete", s.handleSubagentComplete)
|
||||
r.Post("/sessions/{id}/summarize", s.handleSummarize)
|
||||
|
||||
// Data routes
|
||||
r.Get("/api/observations", s.handleGetObservations)
|
||||
r.Get("/api/summaries", s.handleGetSummaries)
|
||||
r.Get("/api/prompts", s.handleGetPrompts)
|
||||
r.Get("/api/projects", s.handleGetProjects)
|
||||
r.Get("/api/stats", s.handleGetStats)
|
||||
r.Get("/api/stats/retrieval", s.handleGetRetrievalStats)
|
||||
|
||||
// Context injection
|
||||
r.Get("/api/context/count", s.handleContextCount)
|
||||
r.Get("/api/context/inject", s.handleContextInject)
|
||||
r.Get("/api/context/search", s.handleSearchByPrompt)
|
||||
})
|
||||
}
|
||||
|
||||
// recordRetrievalStats atomically updates retrieval statistics.
|
||||
func (s *Service) recordRetrievalStats(served, verified, deleted int64, isSearch bool) {
|
||||
atomic.AddInt64(&s.retrievalStats.TotalRequests, 1)
|
||||
atomic.AddInt64(&s.retrievalStats.ObservationsServed, served)
|
||||
atomic.AddInt64(&s.retrievalStats.VerifiedStale, verified)
|
||||
atomic.AddInt64(&s.retrievalStats.DeletedInvalid, deleted)
|
||||
if isSearch {
|
||||
atomic.AddInt64(&s.retrievalStats.SearchRequests, 1)
|
||||
} else {
|
||||
atomic.AddInt64(&s.retrievalStats.ContextInjections, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// GetRetrievalStats returns a copy of the retrieval stats.
|
||||
func (s *Service) GetRetrievalStats() RetrievalStats {
|
||||
return RetrievalStats{
|
||||
TotalRequests: atomic.LoadInt64(&s.retrievalStats.TotalRequests),
|
||||
ObservationsServed: atomic.LoadInt64(&s.retrievalStats.ObservationsServed),
|
||||
VerifiedStale: atomic.LoadInt64(&s.retrievalStats.VerifiedStale),
|
||||
DeletedInvalid: atomic.LoadInt64(&s.retrievalStats.DeletedInvalid),
|
||||
SearchRequests: atomic.LoadInt64(&s.retrievalStats.SearchRequests),
|
||||
ContextInjections: atomic.LoadInt64(&s.retrievalStats.ContextInjections),
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the worker service.
|
||||
// The HTTP server starts immediately; database initialization happens async.
|
||||
func (s *Service) Start() error {
|
||||
port := config.GetWorkerPort()
|
||||
|
||||
s.server = &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", port),
|
||||
Handler: s.router,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
if err := s.server.ListenAndServe(); err != http.ErrServerClosed {
|
||||
log.Error().Err(err).Msg("HTTP server error")
|
||||
}
|
||||
}()
|
||||
|
||||
// Note: Queue processor is started in initializeAsync() after DB is ready
|
||||
|
||||
log.Info().
|
||||
Int("port", port).
|
||||
Int("pid", getPID()).
|
||||
Msg("Worker HTTP server started (initialization in progress)")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// processQueue processes the observation queue in the background.
|
||||
func (s *Service) processQueue() {
|
||||
defer s.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(QueueProcessInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.processAllSessions()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processAllSessions processes pending messages for all active sessions.
|
||||
func (s *Service) processAllSessions() {
|
||||
// Get all sessions with pending messages
|
||||
sessions := s.sessionManager.GetAllSessions()
|
||||
|
||||
for _, sess := range sessions {
|
||||
// Get pending messages
|
||||
messages := s.sessionManager.DrainMessages(sess.SessionDBID)
|
||||
if len(messages) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Process each message
|
||||
for _, msg := range messages {
|
||||
switch msg.Type {
|
||||
case session.MessageTypeObservation:
|
||||
if msg.Observation != nil {
|
||||
err := s.processor.ProcessObservation(
|
||||
s.ctx,
|
||||
sess.SDKSessionID,
|
||||
sess.Project,
|
||||
msg.Observation.ToolName,
|
||||
msg.Observation.ToolInput,
|
||||
msg.Observation.ToolResponse,
|
||||
msg.Observation.PromptNumber,
|
||||
msg.Observation.CWD,
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().Err(err).
|
||||
Str("tool", msg.Observation.ToolName).
|
||||
Msg("Failed to process observation")
|
||||
}
|
||||
}
|
||||
|
||||
case session.MessageTypeSummarize:
|
||||
if msg.Summarize != nil {
|
||||
err := s.processor.ProcessSummary(
|
||||
s.ctx,
|
||||
sess.SessionDBID,
|
||||
sess.SDKSessionID,
|
||||
sess.Project,
|
||||
sess.UserPrompt,
|
||||
msg.Summarize.LastUserMessage,
|
||||
msg.Summarize.LastAssistantMessage,
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().Err(err).
|
||||
Int64("sessionId", sess.SessionDBID).
|
||||
Msg("Failed to process summary")
|
||||
}
|
||||
// Delete session after summary
|
||||
s.sessionManager.DeleteSession(sess.SessionDBID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.broadcastProcessingStatus()
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the service.
|
||||
func (s *Service) Shutdown(ctx context.Context) error {
|
||||
s.cancel()
|
||||
|
||||
// Stop file watchers
|
||||
if s.dbWatcher != nil {
|
||||
_ = s.dbWatcher.Stop()
|
||||
}
|
||||
if s.configWatcher != nil {
|
||||
_ = s.configWatcher.Stop()
|
||||
}
|
||||
|
||||
// Shutdown all sessions
|
||||
s.sessionManager.ShutdownAll(ctx)
|
||||
|
||||
// Shutdown HTTP server
|
||||
if s.server != nil {
|
||||
if err := s.server.Shutdown(ctx); err != nil {
|
||||
log.Error().Err(err).Msg("HTTP server shutdown error")
|
||||
}
|
||||
}
|
||||
|
||||
// Close ChromaDB client
|
||||
if s.chromaClient != nil {
|
||||
if err := s.chromaClient.Close(); err != nil {
|
||||
log.Error().Err(err).Msg("ChromaDB close error")
|
||||
}
|
||||
}
|
||||
|
||||
// Close database
|
||||
if err := s.store.Close(); err != nil {
|
||||
log.Error().Err(err).Msg("Database close error")
|
||||
}
|
||||
|
||||
s.wg.Wait()
|
||||
|
||||
log.Info().Msg("Worker service shutdown complete")
|
||||
return nil
|
||||
}
|
||||
|
||||
// broadcastProcessingStatus broadcasts the current processing status.
|
||||
func (s *Service) broadcastProcessingStatus() {
|
||||
isProcessing := s.sessionManager.IsAnySessionProcessing()
|
||||
queueDepth := s.sessionManager.GetTotalQueueDepth()
|
||||
|
||||
s.sseBroadcaster.Broadcast(map[string]interface{}{
|
||||
"type": "processing_status",
|
||||
"isProcessing": isProcessing,
|
||||
"queueDepth": queueDepth,
|
||||
})
|
||||
}
|
||||
|
||||
func getPID() int {
|
||||
return os.Getpid()
|
||||
}
|
||||
@@ -0,0 +1,346 @@
|
||||
// Package session provides session lifecycle management for claude-mnemonic.
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// MessageType represents the type of pending message.
|
||||
type MessageType int
|
||||
|
||||
const (
|
||||
MessageTypeObservation MessageType = iota
|
||||
MessageTypeSummarize
|
||||
)
|
||||
|
||||
// ObservationData contains data for a tool observation.
|
||||
type ObservationData struct {
|
||||
ToolName string
|
||||
ToolInput interface{}
|
||||
ToolResponse interface{}
|
||||
PromptNumber int
|
||||
CWD string
|
||||
}
|
||||
|
||||
// SummarizeData contains data for a summarize request.
|
||||
type SummarizeData struct {
|
||||
LastUserMessage string
|
||||
LastAssistantMessage string
|
||||
}
|
||||
|
||||
// PendingMessage represents a message queued for SDK processing.
|
||||
type PendingMessage struct {
|
||||
Type MessageType
|
||||
Observation *ObservationData
|
||||
Summarize *SummarizeData
|
||||
}
|
||||
|
||||
// ActiveSession represents an in-memory active session being processed.
|
||||
type ActiveSession struct {
|
||||
SessionDBID int64
|
||||
ClaudeSessionID string
|
||||
SDKSessionID string
|
||||
Project string
|
||||
UserPrompt string
|
||||
LastPromptNumber int
|
||||
StartTime time.Time
|
||||
CumulativeInputTokens int64
|
||||
CumulativeOutputTokens int64
|
||||
|
||||
// Concurrency control
|
||||
pendingMessages []PendingMessage
|
||||
messageMu sync.Mutex
|
||||
notify chan struct{}
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
generatorActive atomic.Bool
|
||||
}
|
||||
|
||||
// Manager manages active session lifecycles.
|
||||
type Manager struct {
|
||||
sessionStore *sqlite.SessionStore
|
||||
sessions map[int64]*ActiveSession
|
||||
mu sync.RWMutex
|
||||
onDeleted func(int64)
|
||||
}
|
||||
|
||||
// NewManager creates a new session manager.
|
||||
func NewManager(sessionStore *sqlite.SessionStore) *Manager {
|
||||
return &Manager{
|
||||
sessionStore: sessionStore,
|
||||
sessions: make(map[int64]*ActiveSession),
|
||||
}
|
||||
}
|
||||
|
||||
// SetOnSessionDeleted sets a callback for when a session is deleted.
|
||||
func (m *Manager) SetOnSessionDeleted(callback func(int64)) {
|
||||
m.onDeleted = callback
|
||||
}
|
||||
|
||||
// InitializeSession initializes a session, creating it if needed.
|
||||
func (m *Manager) InitializeSession(ctx context.Context, sessionDBID int64, userPrompt string, promptNumber int) (*ActiveSession, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Check if already active
|
||||
if session, ok := m.sessions[sessionDBID]; ok {
|
||||
// Update user prompt for continuation
|
||||
if userPrompt != "" {
|
||||
session.UserPrompt = userPrompt
|
||||
session.LastPromptNumber = promptNumber
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// Fetch from database
|
||||
dbSession, err := m.sessionStore.GetSessionByID(ctx, sessionDBID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if dbSession == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Use provided userPrompt or fall back to database
|
||||
prompt := userPrompt
|
||||
if prompt == "" && dbSession.UserPrompt.Valid {
|
||||
prompt = dbSession.UserPrompt.String
|
||||
}
|
||||
|
||||
// Get prompt counter if not provided
|
||||
if promptNumber <= 0 {
|
||||
promptNumber, _ = m.sessionStore.GetPromptCounter(ctx, sessionDBID)
|
||||
}
|
||||
|
||||
// Create session context
|
||||
sessionCtx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
session := &ActiveSession{
|
||||
SessionDBID: sessionDBID,
|
||||
ClaudeSessionID: dbSession.ClaudeSessionID,
|
||||
SDKSessionID: dbSession.SDKSessionID.String,
|
||||
Project: dbSession.Project,
|
||||
UserPrompt: prompt,
|
||||
LastPromptNumber: promptNumber,
|
||||
StartTime: time.Now(),
|
||||
pendingMessages: make([]PendingMessage, 0, 32),
|
||||
notify: make(chan struct{}, 1),
|
||||
ctx: sessionCtx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
m.sessions[sessionDBID] = session
|
||||
|
||||
log.Info().
|
||||
Int64("sessionId", sessionDBID).
|
||||
Str("project", session.Project).
|
||||
Str("claudeSessionId", session.ClaudeSessionID).
|
||||
Msg("Session initialized")
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// QueueObservation queues an observation for SDK processing.
|
||||
func (m *Manager) QueueObservation(ctx context.Context, sessionDBID int64, data ObservationData) error {
|
||||
m.mu.Lock()
|
||||
session, ok := m.sessions[sessionDBID]
|
||||
if !ok {
|
||||
// Auto-initialize from database
|
||||
m.mu.Unlock()
|
||||
var err error
|
||||
session, err = m.InitializeSession(ctx, sessionDBID, "", 0)
|
||||
if err != nil || session == nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
session.messageMu.Lock()
|
||||
session.pendingMessages = append(session.pendingMessages, PendingMessage{
|
||||
Type: MessageTypeObservation,
|
||||
Observation: &data,
|
||||
})
|
||||
queueDepth := len(session.pendingMessages)
|
||||
session.messageMu.Unlock()
|
||||
|
||||
// Non-blocking notification
|
||||
select {
|
||||
case session.notify <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Int64("sessionId", sessionDBID).
|
||||
Str("tool", data.ToolName).
|
||||
Int("queueDepth", queueDepth).
|
||||
Msg("Observation queued")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueueSummarize queues a summarize request for SDK processing.
|
||||
func (m *Manager) QueueSummarize(ctx context.Context, sessionDBID int64, lastUserMessage, lastAssistantMessage string) error {
|
||||
m.mu.Lock()
|
||||
session, ok := m.sessions[sessionDBID]
|
||||
if !ok {
|
||||
// Auto-initialize from database
|
||||
m.mu.Unlock()
|
||||
var err error
|
||||
session, err = m.InitializeSession(ctx, sessionDBID, "", 0)
|
||||
if err != nil || session == nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
session.messageMu.Lock()
|
||||
session.pendingMessages = append(session.pendingMessages, PendingMessage{
|
||||
Type: MessageTypeSummarize,
|
||||
Summarize: &SummarizeData{
|
||||
LastUserMessage: lastUserMessage,
|
||||
LastAssistantMessage: lastAssistantMessage,
|
||||
},
|
||||
})
|
||||
queueDepth := len(session.pendingMessages)
|
||||
session.messageMu.Unlock()
|
||||
|
||||
// Non-blocking notification
|
||||
select {
|
||||
case session.notify <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Int64("sessionId", sessionDBID).
|
||||
Int("queueDepth", queueDepth).
|
||||
Msg("Summarize request queued")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteSession removes a session and cleans up resources.
|
||||
func (m *Manager) DeleteSession(sessionDBID int64) {
|
||||
m.mu.Lock()
|
||||
session, ok := m.sessions[sessionDBID]
|
||||
if !ok {
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
delete(m.sessions, sessionDBID)
|
||||
m.mu.Unlock()
|
||||
|
||||
// Cancel context to stop generator
|
||||
session.cancel()
|
||||
|
||||
duration := time.Since(session.StartTime)
|
||||
log.Info().
|
||||
Int64("sessionId", sessionDBID).
|
||||
Str("project", session.Project).
|
||||
Dur("duration", duration).
|
||||
Msg("Session deleted")
|
||||
|
||||
// Trigger callback
|
||||
if m.onDeleted != nil {
|
||||
m.onDeleted(sessionDBID)
|
||||
}
|
||||
}
|
||||
|
||||
// ShutdownAll shuts down all active sessions.
|
||||
func (m *Manager) ShutdownAll(ctx context.Context) {
|
||||
m.mu.Lock()
|
||||
sessionIDs := make([]int64, 0, len(m.sessions))
|
||||
for id := range m.sessions {
|
||||
sessionIDs = append(sessionIDs, id)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
for _, id := range sessionIDs {
|
||||
m.DeleteSession(id)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Int("count", len(sessionIDs)).
|
||||
Msg("All sessions shut down")
|
||||
}
|
||||
|
||||
// GetActiveSessionCount returns the number of active sessions.
|
||||
func (m *Manager) GetActiveSessionCount() int {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return len(m.sessions)
|
||||
}
|
||||
|
||||
// GetTotalQueueDepth returns the total queue depth across all sessions.
|
||||
func (m *Manager) GetTotalQueueDepth() int {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
total := 0
|
||||
for _, session := range m.sessions {
|
||||
session.messageMu.Lock()
|
||||
total += len(session.pendingMessages)
|
||||
session.messageMu.Unlock()
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// IsAnySessionProcessing returns true if any session is actively processing.
|
||||
func (m *Manager) IsAnySessionProcessing() bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
for _, session := range m.sessions {
|
||||
// Check for pending messages
|
||||
session.messageMu.Lock()
|
||||
hasPending := len(session.pendingMessages) > 0
|
||||
session.messageMu.Unlock()
|
||||
if hasPending {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for active generator
|
||||
if session.generatorActive.Load() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetAllSessions returns a copy of all active sessions.
|
||||
func (m *Manager) GetAllSessions() []*ActiveSession {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
sessions := make([]*ActiveSession, 0, len(m.sessions))
|
||||
for _, session := range m.sessions {
|
||||
sessions = append(sessions, session)
|
||||
}
|
||||
return sessions
|
||||
}
|
||||
|
||||
// DrainMessages drains and returns all pending messages for a session.
|
||||
func (m *Manager) DrainMessages(sessionDBID int64) []PendingMessage {
|
||||
m.mu.RLock()
|
||||
session, ok := m.sessions[sessionDBID]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
session.messageMu.Lock()
|
||||
messages := make([]PendingMessage, len(session.pendingMessages))
|
||||
copy(messages, session.pendingMessages)
|
||||
session.pendingMessages = session.pendingMessages[:0]
|
||||
session.messageMu.Unlock()
|
||||
|
||||
return messages
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
// Package sse provides Server-Sent Events broadcasting for claude-mnemonic.
|
||||
package sse
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Client represents a connected SSE client.
|
||||
type Client struct {
|
||||
ID string
|
||||
Writer http.ResponseWriter
|
||||
Flusher http.Flusher
|
||||
Done chan struct{}
|
||||
}
|
||||
|
||||
// Broadcaster manages SSE client connections and message broadcasting.
|
||||
type Broadcaster struct {
|
||||
clients map[string]*Client
|
||||
mu sync.RWMutex
|
||||
nextID int
|
||||
}
|
||||
|
||||
// NewBroadcaster creates a new SSE broadcaster.
|
||||
func NewBroadcaster() *Broadcaster {
|
||||
return &Broadcaster{
|
||||
clients: make(map[string]*Client),
|
||||
}
|
||||
}
|
||||
|
||||
// AddClient adds a new SSE client connection.
|
||||
func (b *Broadcaster) AddClient(w http.ResponseWriter) (*Client, error) {
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("streaming not supported")
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
b.nextID++
|
||||
id := fmt.Sprintf("client-%d", b.nextID)
|
||||
client := &Client{
|
||||
ID: id,
|
||||
Writer: w,
|
||||
Flusher: flusher,
|
||||
Done: make(chan struct{}),
|
||||
}
|
||||
b.clients[id] = client
|
||||
clientCount := len(b.clients)
|
||||
b.mu.Unlock()
|
||||
|
||||
log.Debug().
|
||||
Str("clientId", id).
|
||||
Int("totalClients", clientCount).
|
||||
Msg("SSE client connected")
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// RemoveClient removes a client connection.
|
||||
func (b *Broadcaster) RemoveClient(client *Client) {
|
||||
b.mu.Lock()
|
||||
delete(b.clients, client.ID)
|
||||
clientCount := len(b.clients)
|
||||
b.mu.Unlock()
|
||||
|
||||
close(client.Done)
|
||||
|
||||
log.Debug().
|
||||
Str("clientId", client.ID).
|
||||
Int("totalClients", clientCount).
|
||||
Msg("SSE client disconnected")
|
||||
}
|
||||
|
||||
// Broadcast sends a message to all connected clients.
|
||||
func (b *Broadcaster) Broadcast(data interface{}) {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to marshal SSE data")
|
||||
return
|
||||
}
|
||||
|
||||
message := fmt.Sprintf("data: %s\n\n", jsonData)
|
||||
|
||||
b.mu.RLock()
|
||||
clients := make([]*Client, 0, len(b.clients))
|
||||
for _, client := range b.clients {
|
||||
clients = append(clients, client)
|
||||
}
|
||||
b.mu.RUnlock()
|
||||
|
||||
for _, client := range clients {
|
||||
select {
|
||||
case <-client.Done:
|
||||
continue
|
||||
default:
|
||||
_, err := client.Writer.Write([]byte(message))
|
||||
if err != nil {
|
||||
log.Debug().
|
||||
Str("clientId", client.ID).
|
||||
Err(err).
|
||||
Msg("Failed to write to SSE client")
|
||||
continue
|
||||
}
|
||||
client.Flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ClientCount returns the number of connected clients.
|
||||
func (b *Broadcaster) ClientCount() int {
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
return len(b.clients)
|
||||
}
|
||||
|
||||
// HandleSSE handles an SSE connection request.
|
||||
func (b *Broadcaster) HandleSSE(w http.ResponseWriter, r *http.Request) {
|
||||
// Set SSE headers
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
|
||||
client, err := b.AddClient(w)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer b.RemoveClient(client)
|
||||
|
||||
// Send initial connection message
|
||||
fmt.Fprintf(w, "data: {\"type\":\"connected\",\"clientId\":\"%s\"}\n\n", client.ID)
|
||||
client.Flusher.Flush()
|
||||
|
||||
// Wait for client disconnect
|
||||
<-r.Context().Done()
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
//go:embed static/*
|
||||
var staticFS embed.FS
|
||||
|
||||
// staticSubFS is the static subdirectory filesystem
|
||||
var staticSubFS fs.FS
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
staticSubFS, err = fs.Sub(staticFS, "static")
|
||||
if err != nil {
|
||||
panic("failed to create sub filesystem: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// serveIndex serves the index.html file for the root path
|
||||
func serveIndex(w http.ResponseWriter, r *http.Request) {
|
||||
content, err := fs.ReadFile(staticSubFS, "index.html")
|
||||
if err != nil {
|
||||
http.Error(w, "Dashboard not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
|
||||
w.Header().Set("Pragma", "no-cache")
|
||||
w.Header().Set("Expires", "0")
|
||||
_, _ = w.Write(content)
|
||||
}
|
||||
|
||||
// serveAssets serves static assets from the embedded filesystem
|
||||
func serveAssets(w http.ResponseWriter, r *http.Request) {
|
||||
// Strip the /assets/ prefix and serve the file
|
||||
path := strings.TrimPrefix(r.URL.Path, "/")
|
||||
|
||||
content, err := fs.ReadFile(staticSubFS, path)
|
||||
if err != nil {
|
||||
http.Error(w, "Asset not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Set content type based on extension
|
||||
if strings.HasSuffix(path, ".js") {
|
||||
w.Header().Set("Content-Type", "application/javascript")
|
||||
} else if strings.HasSuffix(path, ".css") {
|
||||
w.Header().Set("Content-Type", "text/css")
|
||||
}
|
||||
|
||||
// No caching - always serve fresh content
|
||||
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
|
||||
w.Header().Set("Pragma", "no-cache")
|
||||
w.Header().Set("Expires", "0")
|
||||
_, _ = w.Write(content)
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// testStore creates a sqlite.Store with a temporary database for testing.
|
||||
// Uses sqlite.NewStore which runs migrations (requires FTS5).
|
||||
// Skips the test if FTS5 is not available.
|
||||
func testStore(t *testing.T) (*sqlite.Store, func()) {
|
||||
t.Helper()
|
||||
|
||||
// First check if FTS5 is available
|
||||
if !hasFTS5ForTest(t) {
|
||||
t.Skip("FTS5 not available in this SQLite build")
|
||||
}
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "claude-mnemonic-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("create temp dir: %v", err)
|
||||
}
|
||||
|
||||
dbPath := tmpDir + "/test.db"
|
||||
|
||||
store, err := sqlite.NewStore(sqlite.StoreConfig{
|
||||
Path: dbPath,
|
||||
MaxConns: 1,
|
||||
WALMode: true,
|
||||
})
|
||||
if err != nil {
|
||||
_ = os.RemoveAll(tmpDir)
|
||||
t.Fatalf("create store: %v", err)
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
_ = store.Close()
|
||||
_ = os.RemoveAll(tmpDir)
|
||||
}
|
||||
|
||||
return store, cleanup
|
||||
}
|
||||
|
||||
// hasFTS5ForTest checks if FTS5 is available in the SQLite build.
|
||||
func hasFTS5ForTest(t *testing.T) bool {
|
||||
t.Helper()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "fts5-check-*")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(tmpDir) }()
|
||||
|
||||
dbPath := tmpDir + "/check.db"
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
_, err = db.Exec("CREATE VIRTUAL TABLE IF NOT EXISTS fts5_test USING fts5(content)")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
_, _ = db.Exec("DROP TABLE IF EXISTS fts5_test")
|
||||
return true
|
||||
}
|
||||
Reference in New Issue
Block a user