Make things 'betterer' across the board (#23)

* Make things 'betterer' across the board

* fix: reorganize struct fields and config parameters for consistency

- [x] Reorder Config struct fields alphabetically and by related functionality
- [x] Reorganize Observation model fields with archival fields grouped together
- [x] Reorder ObservationStore fields to group related members
- [x] Reorder Store struct fields with health check caching grouped
- [x] Reorganize HealthInfo and PoolMetrics struct field order
- [x] Reorder maintenance Service struct fields logically
- [x] Reorganize MCP server handler parameter structs alphabetically
- [x] Reorder pattern detector candidate tracking fields
- [x] Reorganize search Manager struct fields by functionality
- [x] Reorder vector Client struct fields with mutex protections grouped
- [x] Reorganize handler request/response struct fields
- [x] Update handlers_test.go to expect wrapped response format
- [x] Reorder middleware TokenAuth and rate limiter fields
- [x] Reorganize Service struct fields with grouped functionality
- [x] Fix RateLimiter field ordering for clarity
- [x] Reorder CircuitBreaker metrics fields

* fix(security): improve JSON output safety and path traversal protection

- [x] Replace unsafe JSON string formatting with proper json.Marshal in export handler
- [x] Remove escapeJSONString helper function in favor of standard JSON marshaling
- [x] Add safeResolvePath function to validate paths and prevent directory traversal
- [x] Apply path traversal validation in captureFileMtimes operations
- [x] Cap result slice capacity in getRecentSearchQueries to prevent DoS via excessive allocation

* fix(sdk): improve path traversal protection and allocation safety

- [x] Enhance safeResolvePath with stricter validation using filepath.Rel
- [x] Reject paths containing ".." after cleaning to prevent traversal
- [x] Validate absolute paths are within cwd when cwd is specified
- [x] Apply safeResolvePath validation to GetFileContent for consistency
- [x] Add comprehensive test coverage for path traversal protection
- [x] Fix allocation safety in getRecentSearchQueries by using constant capacity
This commit is contained in:
Lukasz Raczylo
2026-01-11 01:51:20 +00:00
committed by GitHub
parent 3107eddeb2
commit d04b60517a
46 changed files with 12710 additions and 2068 deletions
+125 -1273
View File
File diff suppressed because it is too large Load Diff
+677
View File
@@ -0,0 +1,677 @@
// Package worker provides context and search-related HTTP handlers.
package worker
import (
"context"
"net/http"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
"github.com/lukaszraczylo/claude-mnemonic/internal/reranking"
"github.com/lukaszraczylo/claude-mnemonic/internal/search/expansion"
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/sdk"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/rs/zerolog/log"
)
// handleSearchByPrompt searches observations relevant to a user prompt.
// IMPORTANT: This is on the critical startup path - must be fast!
// No synchronous verification - just filter by staleness and return.
func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
query := r.URL.Query().Get("query")
cwd := r.URL.Query().Get("cwd")
if project == "" || query == "" {
http.Error(w, "project and query required", http.StatusBadRequest)
return
}
// Validate project name to prevent path traversal
if err := ValidateProjectName(project); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
limit := gorm.ParseLimitParamWithMax(r, DefaultSearchLimit, 200)
var observations []*models.Observation
var err error
var usedVector bool
similarityScores := make(map[int64]float64) // Track similarity per observation
// Get threshold settings from config
threshold := s.config.ContextRelevanceThreshold
maxResults := s.config.ContextMaxPromptResults
// Generate expanded queries if query expander is available
// Use timeout context to prevent query expansion from blocking
var expandedQueries []expansion.ExpandedQuery
var detectedIntent string
if s.queryExpander != nil {
expandCtx, expandCancel := context.WithTimeout(r.Context(), 5*time.Second)
cfg := expansion.DefaultConfig()
cfg.EnableVocabularyExpansion = false // Vocabulary expansion is optional
expandedQueries = s.queryExpander.Expand(expandCtx, query, cfg)
expandCancel() // Cancel immediately after use (defer not needed - no panic possible between creation and here)
if len(expandedQueries) > 0 {
detectedIntent = string(expandedQueries[0].Intent)
}
}
if len(expandedQueries) == 0 {
// Fallback to just the original query
expandedQueries = []expansion.ExpandedQuery{
{Query: query, Weight: 1.0, Source: "original"},
}
}
// Try vector search first if available
var vectorSearchFailed bool
if s.vectorClient != nil && s.vectorClient.IsConnected() {
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "")
// Search with each expanded query and merge results
// Pre-allocate with estimated capacity to avoid repeated reallocation
estimatedCapacity := len(expandedQueries) * limit * 2
allVectorResults := make([]sqlitevec.QueryResult, 0, estimatedCapacity)
queryWeights := make(map[string]float64, len(expandedQueries))
var vectorErrors int
for _, eq := range expandedQueries {
vectorResults, vecErr := s.vectorClient.Query(r.Context(), eq.Query, limit*2, where)
if vecErr != nil {
vectorErrors++
log.Debug().Err(vecErr).Str("query", eq.Query).Msg("Vector query failed")
} else if len(vectorResults) > 0 {
// Apply weight to similarity scores before merging
for i := range vectorResults {
vectorResults[i].Similarity *= eq.Weight
}
allVectorResults = append(allVectorResults, vectorResults...)
queryWeights[eq.Query] = eq.Weight
}
}
// Track if vector search had issues
if vectorErrors > 0 && vectorErrors == len(expandedQueries) {
vectorSearchFailed = true
log.Warn().Int("errors", vectorErrors).Str("project", project).Msg("All vector queries failed, falling back to FTS")
}
if len(allVectorResults) > 0 {
// Filter by relevance threshold before extracting IDs
// Use a slightly lower threshold for expanded queries
effectiveThreshold := threshold * 0.9 // Allow slightly lower scores for expanded queries
filteredResults := sqlitevec.FilterByThreshold(allVectorResults, effectiveThreshold, 0)
// Build similarity map for filtered results (keeping highest weighted score per observation)
for _, vr := range filteredResults {
if sqliteID, ok := vr.Metadata["sqlite_id"].(float64); ok {
id := int64(sqliteID)
// Keep the highest score for each observation
if existing, exists := similarityScores[id]; !exists || vr.Similarity > existing {
similarityScores[id] = vr.Similarity
}
}
}
// Extract observation IDs with project/scope filtering using shared helper
obsIDs := sqlitevec.ExtractObservationIDs(filteredResults, project)
if len(obsIDs) > 0 {
// Fetch full observations from SQLite
observations, err = s.observationStore.GetObservationsByIDs(r.Context(), obsIDs, "date_desc", limit)
if err == nil {
usedVector = true
}
}
}
}
// Fall back to FTS if vector search not available, failed, or returned no results
if !usedVector || len(observations) == 0 {
if vectorSearchFailed {
log.Info().Str("project", project).Msg("Using FTS fallback due to vector search failure")
}
observations, err = s.observationStore.SearchObservationsFTS(r.Context(), query, project, limit)
if err != nil {
// FTS might fail if query has special chars, try without
log.Warn().Err(err).Str("query", query).Msg("FTS search failed, falling back to recent")
observations, err = s.observationStore.GetRecentObservations(r.Context(), project, limit)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
}
// Fast staleness filter - NO verification (that's too slow for interactive use)
// Just check mtimes and exclude obviously stale observations
var staleCount int
freshObservations := make([]*models.Observation, 0, len(observations))
for _, obs := range observations {
if len(obs.FileMtimes) > 0 && cwd != "" {
var paths []string
for path := range obs.FileMtimes {
paths = append(paths, path)
}
currentMtimes := sdk.GetFileMtimes(paths, cwd)
if obs.CheckStaleness(currentMtimes) {
// Stale - exclude but don't verify (too slow)
// Queue for background verification instead
staleCount++
s.queueStaleVerification(obs.ID, cwd)
continue
}
}
freshObservations = append(freshObservations, obs)
}
// Apply cross-encoder reranking if available
var reranked bool
if s.reranker != nil && len(freshObservations) > 0 && usedVector {
// Build candidates from observations with their bi-encoder scores
candidates := make([]reranking.Candidate, len(freshObservations))
for i, obs := range freshObservations {
// Use strings.Builder for efficient concatenation
var content string
if obs.Narrative.Valid && obs.Narrative.String != "" {
var sb strings.Builder
sb.Grow(len(obs.Title.String) + 1 + len(obs.Narrative.String))
sb.WriteString(obs.Title.String)
sb.WriteByte(' ')
sb.WriteString(obs.Narrative.String)
content = sb.String()
} else {
content = obs.Title.String
}
candidates[i] = reranking.Candidate{
ID: strconv.FormatInt(obs.ID, 10), // Faster than fmt.Sprintf
Content: content,
Score: similarityScores[obs.ID],
Metadata: map[string]any{"obs_idx": i},
}
}
// Rerank using cross-encoder - use pure mode or combined scores
var rerankResults []reranking.RerankResult
var rerankErr error
if s.config.RerankingPureMode {
rerankResults, rerankErr = s.reranker.RerankByScore(query, candidates, s.config.RerankingResults)
} else {
rerankResults, rerankErr = s.reranker.Rerank(query, candidates, s.config.RerankingResults)
}
if rerankErr != nil {
log.Warn().Err(rerankErr).Msg("Cross-encoder reranking failed, using original order")
} else if len(rerankResults) > 0 {
// Update similarity scores with reranked scores
for _, rr := range rerankResults {
if id, err := strconv.ParseInt(rr.ID, 10, 64); err == nil {
similarityScores[id] = rr.CombinedScore
}
}
// Reorder observations based on rerank results
reorderedObs := make([]*models.Observation, 0, len(rerankResults))
obsMap := make(map[int64]*models.Observation)
for _, obs := range freshObservations {
obsMap[obs.ID] = obs
}
for _, rr := range rerankResults {
if id, err := strconv.ParseInt(rr.ID, 10, 64); err == nil {
if obs, ok := obsMap[id]; ok {
reorderedObs = append(reorderedObs, obs)
}
}
}
freshObservations = reorderedObs
reranked = true
log.Debug().
Int("candidates", len(candidates)).
Int("returned", len(rerankResults)).
Msg("Cross-encoder reranking complete")
}
}
// Cluster similar observations to remove duplicates
clusteredObservations := clusterObservations(freshObservations, 0.4)
duplicatesRemoved := len(freshObservations) - len(clusteredObservations)
// Sort by similarity score (highest first) if we have scores and didn't rerank
if len(similarityScores) > 0 && len(clusteredObservations) > 0 && !reranked {
sort.Slice(clusteredObservations, func(i, j int) bool {
scoreI := similarityScores[clusteredObservations[i].ID]
scoreJ := similarityScores[clusteredObservations[j].ID]
return scoreI > scoreJ
})
}
// Apply max results cap if configured
if maxResults > 0 && len(clusteredObservations) > maxResults {
clusteredObservations = clusteredObservations[:maxResults]
}
// Record retrieval stats with staleness metrics
s.recordRetrievalStatsExtended(project, int64(len(clusteredObservations)), 0, 0,
int64(staleCount), int64(len(freshObservations)), int64(duplicatesRemoved), true)
// Increment retrieval counts for scoring (async, non-blocking)
if len(clusteredObservations) > 0 {
ids := make([]int64, len(clusteredObservations))
for i, obs := range clusteredObservations {
ids[i] = obs.ID
}
s.incrementRetrievalCounts(ids)
}
log.Info().
Str("project", project).
Str("query", query).
Str("intent", detectedIntent).
Int("expansions", len(expandedQueries)).
Int("found", len(clusteredObservations)).
Int("stale_excluded", staleCount).
Float64("threshold", threshold).
Msg("Prompt-based observation search")
// Build response with similarity scores
obsWithScores := make([]map[string]any, len(clusteredObservations))
for i, obs := range clusteredObservations {
obsMap := obs.ToMap()
if score, ok := similarityScores[obs.ID]; ok {
obsMap["similarity"] = score
}
obsWithScores[i] = obsMap
}
// Build expansion info for response
expansionInfo := make([]map[string]any, len(expandedQueries))
for i, eq := range expandedQueries {
expansionInfo[i] = map[string]any{
"query": eq.Query,
"weight": eq.Weight,
"source": eq.Source,
}
}
// Track this search for analytics
s.trackSearchQuery(query, project, "observations", len(clusteredObservations), usedVector)
writeJSON(w, map[string]any{
"project": project,
"query": query,
"intent": detectedIntent,
"expansions": expansionInfo,
"observations": obsWithScores,
"threshold": threshold,
"max_results": maxResults,
})
}
// handleFileContext returns observations relevant to specific files being worked on.
// Uses vector similarity search to find observations that mention or relate to the given files.
func (s *Service) handleFileContext(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
if project == "" {
http.Error(w, "project required", http.StatusBadRequest)
return
}
filesParam := r.URL.Query().Get("files")
if filesParam == "" {
http.Error(w, "files required", http.StatusBadRequest)
return
}
// Parse comma-separated file paths
files := strings.Split(filesParam, ",")
if len(files) == 0 {
http.Error(w, "at least one file required", http.StatusBadRequest)
return
}
// Limit to reasonable number of files
maxFiles := 20
if len(files) > maxFiles {
files = files[:maxFiles]
}
// Get limit parameter (default 10 per file)
limitStr := r.URL.Query().Get("limit")
limit := 10
if limitStr != "" {
if parsed, err := strconv.Atoi(limitStr); err == nil && parsed > 0 && parsed <= 50 {
limit = parsed
}
}
// Search for observations related to each file in parallel
ctx := r.Context()
// Check if vector search is available
if s.vectorClient == nil || !s.vectorClient.IsConnected() {
writeJSON(w, map[string]any{
"files": files,
"results": map[string]any{},
"count": 0,
"error": "vector search not available",
})
return
}
// Prepare for parallel execution
type fileResult struct {
file string
results []map[string]any
obsIDs []int64 // Track observation IDs for deduplication
}
resultsChan := make(chan fileResult, len(files))
sem := make(chan struct{}, 5) // Limit concurrency to 5 parallel searches
var wg sync.WaitGroup
for _, file := range files {
file = strings.TrimSpace(file)
if file == "" {
continue
}
wg.Add(1)
go func(file string) {
defer wg.Done()
sem <- struct{}{} // Acquire semaphore
defer func() { <-sem }() // Release semaphore
// Build search query from file path
query := buildFileQuery(file)
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "")
vectorResults, vecErr := s.vectorClient.Query(ctx, query, limit*2, where)
if vecErr != nil {
log.Warn().Err(vecErr).Str("file", file).Msg("Vector search failed for file context")
return
}
// Extract observation IDs from vector results
obsIDs := sqlitevec.ExtractObservationIDs(vectorResults, project)
if len(obsIDs) == 0 {
return
}
// Fetch observations
observations, err := s.observationStore.GetObservationsByIDs(ctx, obsIDs, "score_desc", limit*2)
if err != nil {
log.Warn().Err(err).Str("file", file).Msg("Failed to fetch observations for file context")
return
}
// Pre-build score map from vector results (O(n) instead of O(n²))
scoreMap := make(map[int64]float64, len(vectorResults))
var avgScore float64
for _, vr := range vectorResults {
avgScore += vr.Similarity
// Parse observation ID from vector result ID (format: obs_{id}_{field})
// Use index-based parsing to avoid slice allocation from strings.Split
if len(vr.ID) > 4 && vr.ID[:4] == "obs_" {
rest := vr.ID[4:] // Skip "obs_"
underscoreIdx := strings.IndexByte(rest, '_')
var idStr string
if underscoreIdx >= 0 {
idStr = rest[:underscoreIdx]
} else {
idStr = rest
}
if id, parseErr := strconv.ParseInt(idStr, 10, 64); parseErr == nil {
// Keep highest score for each observation
if existing, exists := scoreMap[id]; !exists || vr.Similarity > existing {
scoreMap[id] = vr.Similarity
}
}
}
}
if len(vectorResults) > 0 {
avgScore /= float64(len(vectorResults))
}
fileResults := make([]map[string]any, 0, limit)
var usedIDs []int64
for _, obs := range observations {
// Check project scope
if obs.Scope == "project" && obs.Project != project {
continue
}
// O(1) score lookup instead of O(n) nested loop
score, found := scoreMap[obs.ID]
if !found {
// Use average score as fallback
score = avgScore
}
// Only include if score is above threshold
if score < 0.3 {
continue
}
fileResults = append(fileResults, map[string]any{
"id": obs.ID,
"title": obs.Title.String,
"type": obs.Type,
"narrative": obs.Narrative.String,
"facts": obs.Facts,
"score": score,
})
usedIDs = append(usedIDs, obs.ID)
if len(fileResults) >= limit {
break
}
}
if len(fileResults) > 0 {
resultsChan <- fileResult{file: file, results: fileResults, obsIDs: usedIDs}
}
}(file)
}
// Close channel when all goroutines complete
go func() {
wg.Wait()
close(resultsChan)
}()
// Collect results and deduplicate
allResults := make(map[string]any)
seenObservationIDs := make(map[int64]bool)
for res := range resultsChan {
// Filter out duplicates that were already seen in other files
dedupedResults := make([]map[string]any, 0, len(res.results))
for i, r := range res.results {
obsID := res.obsIDs[i]
if !seenObservationIDs[obsID] {
seenObservationIDs[obsID] = true
dedupedResults = append(dedupedResults, r)
}
}
if len(dedupedResults) > 0 {
allResults[res.file] = dedupedResults
}
}
writeJSON(w, map[string]any{
"files": files,
"results": allResults,
"count": len(allResults),
})
}
// buildFileQuery extracts meaningful search terms from a file path.
func buildFileQuery(filePath string) string {
// Remove common prefixes and extensions
path := strings.TrimPrefix(filePath, "/")
// Extract the filename and directory
parts := strings.Split(path, "/")
meaningful := make([]string, 0, len(parts))
for _, part := range parts {
// Skip common directory names that aren't meaningful
switch strings.ToLower(part) {
case "src", "lib", "internal", "pkg", "cmd", "api", "app", "test", "tests", "spec", "specs":
continue
default:
// Remove file extension
if idx := strings.LastIndex(part, "."); idx > 0 {
part = part[:idx]
}
// Convert camelCase/PascalCase to spaces
part = splitCamelCase(part)
// Convert snake_case to spaces
part = strings.ReplaceAll(part, "_", " ")
// Convert kebab-case to spaces
part = strings.ReplaceAll(part, "-", " ")
meaningful = append(meaningful, part)
}
}
return strings.Join(meaningful, " ")
}
// splitCamelCase splits camelCase or PascalCase into separate words.
func splitCamelCase(s string) string {
var result strings.Builder
for i, r := range s {
if i > 0 && r >= 'A' && r <= 'Z' {
result.WriteRune(' ')
}
result.WriteRune(r)
}
return result.String()
}
// handleContextInject returns context for injection at session start.
// IMPORTANT: This is on the critical startup path - must be fast!
// No synchronous verification - just filter by staleness and return.
func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
if project == "" {
http.Error(w, "project required", http.StatusBadRequest)
return
}
// Validate project name to prevent path traversal
if err := ValidateProjectName(project); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
cwd := r.URL.Query().Get("cwd")
if cwd == "" {
cwd = "/"
}
// Limit observations for fast startup (configurable, default 100)
limit := s.config.ContextObservations
if limit <= 0 {
limit = DefaultContextLimit
}
// Full count determines how many observations get full detail (configurable, default 25)
fullCount := s.config.ContextFullCount
if fullCount <= 0 {
fullCount = 25
}
// Get recent observations
observations, err := s.observationStore.GetRecentObservations(r.Context(), project, limit)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Fast staleness filter - NO verification (that's too slow for startup)
var staleCount int
freshObservations := make([]*models.Observation, 0, len(observations))
for _, obs := range observations {
if len(obs.FileMtimes) > 0 {
var paths []string
for path := range obs.FileMtimes {
paths = append(paths, path)
}
currentMtimes := sdk.GetFileMtimes(paths, cwd)
if obs.CheckStaleness(currentMtimes) {
// Stale - exclude but don't verify (too slow)
// Queue for background verification instead
staleCount++
s.queueStaleVerification(obs.ID, cwd)
continue
}
}
freshObservations = append(freshObservations, obs)
}
// Cluster similar observations to remove duplicates
clusteredObservations := clusterObservations(freshObservations, 0.4)
duplicatesRemoved := len(freshObservations) - len(clusteredObservations)
// Record retrieval stats with staleness metrics
s.recordRetrievalStatsExtended(project, int64(len(clusteredObservations)), 0, 0,
int64(staleCount), int64(len(freshObservations)), int64(duplicatesRemoved), false)
// Increment retrieval counts for scoring (async, non-blocking)
if len(clusteredObservations) > 0 {
ids := make([]int64, len(clusteredObservations))
for i, obs := range clusteredObservations {
ids[i] = obs.ID
}
s.incrementRetrievalCounts(ids)
}
log.Info().
Str("project", project).
Int("total", len(observations)).
Int("fresh", len(freshObservations)).
Int("clustered", len(clusteredObservations)).
Int("duplicates", duplicatesRemoved).
Int("stale_excluded", staleCount).
Msg("Context injection with clustering")
writeJSON(w, map[string]any{
"project": project,
"observations": clusteredObservations,
"full_count": fullCount,
"stale_excluded": staleCount,
"duplicates_removed": duplicatesRemoved,
})
}
// handleContextCount returns the count of observations for a project.
func (s *Service) handleContextCount(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
if project == "" {
http.Error(w, "project required", http.StatusBadRequest)
return
}
count, err := s.getCachedObservationCount(r.Context(), project)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
writeJSON(w, map[string]any{
"project": project,
"count": count,
})
}
+595
View File
@@ -0,0 +1,595 @@
// Package worker provides data retrieval HTTP handlers.
package worker
import (
"encoding/json"
"net/http"
"runtime"
"sort"
"strings"
"time"
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/rs/zerolog/log"
)
// handleGetObservations returns recent observations.
// Supports optional query parameter for semantic search via sqlite-vec.
// Supports pagination via limit and offset query parameters.
func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request) {
pagination := gorm.ParsePaginationParams(r, DefaultObservationsLimit)
project := r.URL.Query().Get("project")
query := r.URL.Query().Get("query")
// Validate project name to prevent path traversal
if err := ValidateProjectName(project); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
var observations []*models.Observation
var total int64
var err error
var usedVector bool
// Use vector search if query is provided and vector client is available
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "")
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, pagination.Limit*2, where)
if vecErr == nil && len(vectorResults) > 0 {
obsIDs := sqlitevec.ExtractObservationIDs(vectorResults, project)
if len(obsIDs) > 0 {
observations, err = s.observationStore.GetObservationsByIDs(r.Context(), obsIDs, "date_desc", pagination.Limit)
if err == nil {
usedVector = true
total = int64(len(observations)) // Vector search doesn't have total, use returned count
}
}
}
}
// Fall back to SQLite if vector search not used
if !usedVector {
if project != "" {
// Strict project filtering for dashboard - only observations from this project
observations, total, err = s.observationStore.GetObservationsByProjectStrictPaginated(r.Context(), project, pagination.Limit, pagination.Offset)
} else {
// All projects
observations, total, err = s.observationStore.GetAllRecentObservationsPaginated(r.Context(), pagination.Limit, pagination.Offset)
}
}
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Ensure we return empty array, not null
if observations == nil {
observations = []*models.Observation{}
}
// Track search if query was provided
if query != "" {
s.trackSearchQuery(query, project, "observations", len(observations), usedVector)
}
// Return paginated response
writeJSON(w, map[string]any{
"observations": observations,
"total": total,
"limit": pagination.Limit,
"offset": pagination.Offset,
"hasMore": int64(pagination.Offset)+int64(len(observations)) < total,
})
}
// handleGetSummaries returns recent summaries.
// Supports optional query parameter for semantic search via sqlite-vec.
func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
limit := gorm.ParseLimitParam(r, DefaultSummariesLimit)
project := r.URL.Query().Get("project")
query := r.URL.Query().Get("query")
// Validate project name to prevent path traversal
if err := ValidateProjectName(project); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
var summaries []*models.SessionSummary
var err error
var usedVector bool
// Use vector search if query is provided and vector client is available
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeSessionSummary, "")
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where)
if vecErr == nil && len(vectorResults) > 0 {
summaryIDs := sqlitevec.ExtractSummaryIDs(vectorResults, project)
if len(summaryIDs) > 0 {
summaries, err = s.summaryStore.GetSummariesByIDs(r.Context(), summaryIDs, "date_desc", limit)
if err == nil {
usedVector = true
}
}
}
}
// Fall back to SQLite if vector search not used
if !usedVector {
if project != "" {
summaries, err = s.summaryStore.GetRecentSummaries(r.Context(), project, limit)
} else {
summaries, err = s.summaryStore.GetAllRecentSummaries(r.Context(), limit)
}
}
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Ensure we return empty array, not null
if summaries == nil {
summaries = []*models.SessionSummary{}
}
writeJSON(w, summaries)
}
// handleGetPrompts returns recent user prompts.
// Supports optional query parameter for semantic search via sqlite-vec.
func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) {
limit := gorm.ParseLimitParam(r, DefaultPromptsLimit)
project := r.URL.Query().Get("project")
query := r.URL.Query().Get("query")
// Validate project name to prevent path traversal
if err := ValidateProjectName(project); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
var prompts []*models.UserPromptWithSession
var err error
var usedVector bool
// Use vector search if query is provided and vector client is available
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeUserPrompt, "")
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where)
if vecErr == nil && len(vectorResults) > 0 {
promptIDs := sqlitevec.ExtractPromptIDs(vectorResults, project)
if len(promptIDs) > 0 {
prompts, err = s.promptStore.GetPromptsByIDs(r.Context(), promptIDs, "date_desc", limit)
if err == nil {
usedVector = true
}
}
}
}
// Fall back to SQLite if vector search not used
if !usedVector {
if project != "" {
prompts, err = s.promptStore.GetRecentUserPromptsByProject(r.Context(), project, limit)
} else {
prompts, err = s.promptStore.GetAllRecentUserPrompts(r.Context(), limit)
}
}
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Ensure we return empty array, not null
if prompts == nil {
prompts = []*models.UserPromptWithSession{}
}
writeJSON(w, prompts)
}
// handleGetProjects returns all projects.
// Response is cacheable for 5 minutes since project list changes infrequently.
func (s *Service) handleGetProjects(w http.ResponseWriter, r *http.Request) {
projects, err := s.sessionStore.GetAllProjects(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Cache for 5 minutes - project list changes infrequently
w.Header().Set("Cache-Control", "public, max-age=300")
writeJSON(w, projects)
}
// handleGetTypes returns the canonical list of observation and concept types.
// This provides a single source of truth for both backend and frontend.
// Response is cacheable as these values never change at runtime.
func (s *Service) handleGetTypes(w http.ResponseWriter, r *http.Request) {
// Cache for 24 hours - these values are compile-time constants
w.Header().Set("Cache-Control", "public, max-age=86400")
writeJSON(w, map[string]any{
"observation_types": ObservationTypes,
"concept_types": ConceptTypes,
})
}
// handleGetModels returns available embedding models.
// Response is cacheable as model list doesn't change without restart.
func (s *Service) handleGetModels(w http.ResponseWriter, _ *http.Request) {
// Cache for 1 hour - model list is static during runtime
w.Header().Set("Cache-Control", "public, max-age=3600")
models := embedding.ListModels()
defaultModel := embedding.GetDefaultModel()
writeJSON(w, map[string]any{
"models": models,
"default": defaultModel,
"current": s.embedSvc.Version(),
})
}
// handleGetStats returns worker statistics.
func (s *Service) handleGetStats(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
// Validate project name to prevent path traversal
if err := ValidateProjectName(project); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
retrievalStats := s.GetRetrievalStats(project)
sessionsToday, _ := s.sessionStore.GetSessionsToday(r.Context())
response := map[string]any{
"uptime": time.Since(s.startTime).String(),
"uptimeSeconds": time.Since(s.startTime).Seconds(),
"activeSessions": s.sessionManager.GetActiveSessionCount(),
"queueDepth": s.sessionManager.GetTotalQueueDepth(),
"isProcessing": s.sessionManager.IsAnySessionProcessing(),
"connectedClients": s.sseBroadcaster.ClientCount(),
"sessionsToday": sessionsToday,
"retrieval": retrievalStats,
"ready": s.ready.Load(),
}
// Add memory stats
var memStats runtime.MemStats
runtime.ReadMemStats(&memStats)
response["memory"] = map[string]any{
"alloc_mb": float64(memStats.Alloc) / 1024 / 1024,
"total_alloc_mb": float64(memStats.TotalAlloc) / 1024 / 1024,
"sys_mb": float64(memStats.Sys) / 1024 / 1024,
"heap_alloc_mb": float64(memStats.HeapAlloc) / 1024 / 1024,
"heap_inuse_mb": float64(memStats.HeapInuse) / 1024 / 1024,
"heap_objects": memStats.HeapObjects,
"goroutines": runtime.NumGoroutine(),
"gc_cycles": memStats.NumGC,
"gc_pause_total_ms": float64(memStats.PauseTotalNs) / 1e6,
}
// Add database health if available
if s.store != nil {
dbHealth := s.store.HealthCheck(r.Context())
response["database"] = map[string]any{
"status": dbHealth.Status,
"query_latency_ms": float64(dbHealth.QueryLatency) / 1e6,
"pool": dbHealth.PoolStats,
"warning": dbHealth.Warning,
}
}
// Add embedding model info
if s.embedSvc != nil {
response["embeddingModel"] = map[string]any{
"name": s.embedSvc.Name(),
"version": s.embedSvc.Version(),
"dimensions": s.embedSvc.Dimensions(),
}
}
// Add vector cache stats
if s.vectorClient != nil {
if count, err := s.vectorClient.Count(r.Context()); err == nil {
response["vectorCount"] = count
}
cacheSize, cacheMax := s.vectorClient.CacheStats()
response["vectorCache"] = map[string]any{
"size": cacheSize,
"max_size": cacheMax,
}
}
// Include project-specific observation count if project is specified
if project != "" {
count, err := s.getCachedObservationCount(r.Context(), project)
if err == nil {
response["projectObservations"] = count
response["project"] = project
}
}
// Add rate limiter stats
if s.rateLimiter != nil {
response["rateLimiter"] = s.rateLimiter.Stats()
}
// Add circuit breaker metrics
if s.processor != nil {
response["circuitBreaker"] = s.processor.CircuitBreakerMetrics()
}
writeJSON(w, response)
}
// handleGetRetrievalStats returns detailed retrieval statistics.
func (s *Service) handleGetRetrievalStats(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
stats := s.GetRetrievalStats(project)
writeJSON(w, stats)
}
// handleGetRecentQueries returns recent search queries for analytics.
func (s *Service) handleGetRecentQueries(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
limit := gorm.ParseLimitParam(r, 20)
queries := s.getRecentSearchQueries(project, limit)
writeJSON(w, map[string]any{
"queries": queries,
"count": len(queries),
"project": project,
})
}
// handleGetSearchAnalytics returns comprehensive search analytics and statistics.
func (s *Service) handleGetSearchAnalytics(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
// Get all recent queries for analysis
queries := s.getRecentSearchQueries(project, maxRecentQueries)
// Calculate analytics
totalQueries := len(queries)
vectorSearches := 0
totalResults := 0
zeroResultQueries := 0
queryTypes := make(map[string]int)
topKeywords := make(map[string]int)
for _, q := range queries {
if q.UsedVector {
vectorSearches++
}
totalResults += q.Results
if q.Results == 0 {
zeroResultQueries++
}
queryTypes[q.Type]++
// Extract keywords (simple word tokenization using iterator)
for word := range strings.FieldsSeq(strings.ToLower(q.Query)) {
if len(word) > 3 { // Skip short words
topKeywords[word]++
}
}
}
// Sort keywords by frequency
type keywordCount struct {
Keyword string `json:"keyword"`
Count int `json:"count"`
}
sortedKeywords := make([]keywordCount, 0, len(topKeywords))
for kw, count := range topKeywords {
sortedKeywords = append(sortedKeywords, keywordCount{Keyword: kw, Count: count})
}
sort.Slice(sortedKeywords, func(i, j int) bool {
return sortedKeywords[i].Count > sortedKeywords[j].Count
})
if len(sortedKeywords) > 10 {
sortedKeywords = sortedKeywords[:10]
}
// Calculate averages
avgResults := float64(0)
vectorSearchRate := float64(0)
zeroResultRate := float64(0)
if totalQueries > 0 {
avgResults = float64(totalResults) / float64(totalQueries)
vectorSearchRate = float64(vectorSearches) / float64(totalQueries) * 100
zeroResultRate = float64(zeroResultQueries) / float64(totalQueries) * 100
}
writeJSON(w, map[string]any{
"total_queries": totalQueries,
"vector_search_rate": vectorSearchRate,
"avg_results": avgResults,
"zero_result_rate": zeroResultRate,
"query_types": queryTypes,
"top_keywords": sortedKeywords,
"project": project,
})
}
// handleVectorHealth returns comprehensive health information about the vector database.
func (s *Service) handleVectorHealth(w http.ResponseWriter, r *http.Request) {
if s.vectorClient == nil {
http.Error(w, "vector client not initialized", http.StatusServiceUnavailable)
return
}
stats, err := s.vectorClient.GetHealthStats(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Add additional computed metrics
healthScore := 100.0
var warnings []string
// Penalize for stale vectors
if stats.TotalVectors > 0 {
staleRatio := float64(stats.StaleVectors) / float64(stats.TotalVectors)
if staleRatio > 0 {
healthScore -= staleRatio * 50 // Up to 50 points off for stale vectors
warnings = append(warnings, formatWarning("%.1f%% vectors need rebuild", staleRatio*100))
}
}
// Check cache effectiveness
cacheHitRate := stats.EmbeddingCache.HitRate()
if cacheHitRate < 20 && (stats.EmbeddingCache.EmbeddingHits+stats.EmbeddingCache.EmbeddingMisses) > 100 {
healthScore -= 10
warnings = append(warnings, formatWarning("Low cache hit rate: %.1f%%", cacheHitRate))
}
// Penalize if rebuild is needed
if stats.NeedsRebuild {
healthScore -= 20
warnings = append(warnings, "Vector rebuild recommended: "+stats.RebuildReason)
}
if healthScore < 0 {
healthScore = 0
}
status := "healthy"
if healthScore < 50 {
status = "unhealthy"
} else if healthScore < 80 {
status = "degraded"
}
writeJSON(w, map[string]any{
"status": status,
"health_score": healthScore,
"warnings": warnings,
"stats": stats,
"cache_hit_rate": cacheHitRate,
})
}
// UpdateObservationRequest is the request body for updating an observation.
type UpdateObservationRequest struct {
Title *string `json:"title,omitempty"`
Subtitle *string `json:"subtitle,omitempty"`
Narrative *string `json:"narrative,omitempty"`
Scope *string `json:"scope,omitempty"`
Facts []string `json:"facts,omitempty"`
Concepts []string `json:"concepts,omitempty"`
FilesRead []string `json:"files_read,omitempty"`
FilesModified []string `json:"files_modified,omitempty"`
}
// handleUpdateObservation updates an existing observation.
// PUT /api/observations/{id}
func (s *Service) handleUpdateObservation(w http.ResponseWriter, r *http.Request) {
// Parse observation ID from URL
id, ok := parseIDParam(w, r.PathValue("id"), "observation")
if !ok {
return
}
// Parse request body
var req UpdateObservationRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest)
return
}
// Build update struct - only include fields that were provided
update := &gorm.ObservationUpdate{}
if req.Title != nil {
update.Title = req.Title
}
if req.Subtitle != nil {
update.Subtitle = req.Subtitle
}
if req.Narrative != nil {
update.Narrative = req.Narrative
}
if req.Facts != nil {
update.Facts = &req.Facts
}
if req.Concepts != nil {
update.Concepts = &req.Concepts
}
if req.FilesRead != nil {
update.FilesRead = &req.FilesRead
}
if req.FilesModified != nil {
update.FilesModified = &req.FilesModified
}
if req.Scope != nil {
// Validate scope
if *req.Scope != "project" && *req.Scope != "global" {
http.Error(w, "scope must be 'project' or 'global'", http.StatusBadRequest)
return
}
update.Scope = req.Scope
}
// Update the observation
updatedObs, err := s.observationStore.UpdateObservation(r.Context(), id, update)
if err != nil {
if strings.Contains(err.Error(), "not found") {
http.Error(w, err.Error(), http.StatusNotFound)
return
}
http.Error(w, "failed to update observation: "+err.Error(), http.StatusInternalServerError)
return
}
// Trigger vector resync for the updated observation
if s.vectorSync != nil {
s.asyncVectorSync(func() {
if err := s.vectorSync.SyncObservation(s.ctx, updatedObs); err != nil {
log.Warn().Err(err).Int64("id", id).Msg("Failed to resync observation vectors after update")
}
})
}
// Broadcast update event
s.sseBroadcaster.Broadcast(map[string]any{
"type": "observation_updated",
"id": id,
})
writeJSON(w, map[string]any{
"observation": updatedObs,
"message": "observation updated successfully",
})
}
// handleGetObservationByID returns a single observation by ID.
// GET /api/observations/{id}
func (s *Service) handleGetObservationByID(w http.ResponseWriter, r *http.Request) {
id, ok := parseIDParam(w, r.PathValue("id"), "observation")
if !ok {
return
}
obs, err := s.observationStore.GetObservationByID(r.Context(), id)
if err != nil {
http.Error(w, "failed to get observation: "+err.Error(), http.StatusInternalServerError)
return
}
if obs == nil {
http.Error(w, "observation not found", http.StatusNotFound)
return
}
writeJSON(w, obs)
}
+680
View File
@@ -0,0 +1,680 @@
// Package worker provides import, export, and archive HTTP handlers.
package worker
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"io"
"net/http"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/go-chi/chi/v5"
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/lukaszraczylo/claude-mnemonic/pkg/similarity"
"github.com/rs/zerolog/log"
)
// BulkImportRequest is the request body for bulk observation import.
type BulkImportRequest struct {
Project string `json:"project"`
Observations []BulkObservationInput `json:"observations"`
}
// BulkObservationInput represents a single observation in bulk import.
type BulkObservationInput struct {
Type string `json:"type"`
Title string `json:"title"`
Subtitle string `json:"subtitle,omitempty"`
Narrative string `json:"narrative,omitempty"`
Scope string `json:"scope,omitempty"`
Facts []string `json:"facts,omitempty"`
Concepts []string `json:"concepts,omitempty"`
FilesRead []string `json:"files_read,omitempty"`
FilesModified []string `json:"files_modified,omitempty"`
}
// BulkImportResponse contains the result of a bulk import operation.
type BulkImportResponse struct {
Errors []string `json:"errors,omitempty"`
Imported int `json:"imported"`
Failed int `json:"failed"`
SkippedDuplicates int `json:"skipped_duplicates,omitempty"`
}
// handleBulkImport handles bulk import of observations.
// This is useful for migrating data or importing observations from external sources.
func (s *Service) handleBulkImport(w http.ResponseWriter, r *http.Request) {
// Rate limit bulk operations to prevent DoS
if s.bulkOpLimiter != nil && !s.bulkOpLimiter.CanExecute() {
remaining := s.bulkOpLimiter.CooldownRemaining()
http.Error(w, fmt.Sprintf("bulk import rate limited, retry in %d seconds", remaining), http.StatusTooManyRequests)
return
}
var req BulkImportRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest)
return
}
if req.Project == "" {
http.Error(w, "project is required", http.StatusBadRequest)
return
}
// Validate project name to prevent path traversal
if err := ValidateProjectName(req.Project); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if len(req.Observations) == 0 {
http.Error(w, "at least one observation is required", http.StatusBadRequest)
return
}
// Limit batch size to prevent overwhelming the system
maxBatchSize := 100
if len(req.Observations) > maxBatchSize {
http.Error(w, fmt.Sprintf("batch size exceeds maximum of %d", maxBatchSize), http.StatusBadRequest)
return
}
// Create a synthetic session for bulk import
sessionID, err := s.sessionStore.CreateSDKSession(r.Context(), fmt.Sprintf("bulk-import-%d", time.Now().UnixMilli()), req.Project, "bulk import")
if err != nil {
http.Error(w, "failed to create import session: "+err.Error(), http.StatusInternalServerError)
return
}
var imported, failed, skippedDupes int
var errors []string
// Track imported observations for deduplication within the batch
importedObs := make([]*models.Observation, 0, len(req.Observations))
// Deduplication threshold - observations more similar than this are considered duplicates
const dedupThreshold = 0.7
for i, obsInput := range req.Observations {
// Validate observation type using O(1) map lookup
if !IsValidObservationType(obsInput.Type) {
failed++
errors = append(errors, fmt.Sprintf("observation %d: invalid type '%s'", i, obsInput.Type))
continue
}
// Build parsed observation
parsedObs := &models.ParsedObservation{
Type: models.ObservationType(obsInput.Type),
Title: obsInput.Title,
Subtitle: obsInput.Subtitle,
Facts: obsInput.Facts,
Narrative: obsInput.Narrative,
Concepts: obsInput.Concepts,
FilesRead: obsInput.FilesRead,
FilesModified: obsInput.FilesModified,
Scope: models.ObservationScope(obsInput.Scope),
}
// Convert to temporary observation for similarity check
tempObs := &models.Observation{
Title: sql.NullString{String: parsedObs.Title, Valid: parsedObs.Title != ""},
Subtitle: sql.NullString{String: parsedObs.Subtitle, Valid: parsedObs.Subtitle != ""},
Narrative: sql.NullString{String: parsedObs.Narrative, Valid: parsedObs.Narrative != ""},
}
// Check for duplicates within this import batch
if similarity.IsSimilarToAny(tempObs, importedObs, dedupThreshold) {
skippedDupes++
continue
}
// Store observation
obsID, _, err := s.observationStore.StoreObservation(
r.Context(),
fmt.Sprintf("bulk-import-%d", sessionID),
req.Project,
parsedObs,
0, // prompt number
0, // discovery tokens
)
if err != nil {
failed++
errors = append(errors, fmt.Sprintf("observation %d: %v", i, err))
continue
}
// Sync to vector DB asynchronously with rate limiting
if s.vectorSync != nil {
s.asyncVectorSync(func() {
// Use service context as parent to respect shutdown signals
ctx, cancel := context.WithTimeout(s.ctx, 10*time.Second)
defer cancel()
obs, err := s.observationStore.GetObservationByID(ctx, obsID)
if err == nil && obs != nil {
if syncErr := s.vectorSync.SyncObservation(ctx, obs); syncErr != nil {
if s.ctx.Err() == nil { // Don't log during shutdown
log.Debug().Err(syncErr).Int64("id", obsID).Msg("Failed to sync observation during bulk import")
}
}
}
})
}
// Track for deduplication of subsequent observations in this batch
importedObs = append(importedObs, tempObs)
imported++
}
log.Info().
Str("project", req.Project).
Int("imported", imported).
Int("failed", failed).
Int("skipped_duplicates", skippedDupes).
Msg("Bulk import completed")
// Invalidate observation count cache after import
if imported > 0 {
if req.Project != "" {
s.invalidateObsCountCache(req.Project)
} else {
s.invalidateAllObsCountCache()
}
}
// Broadcast observation event for dashboard refresh
s.sseBroadcaster.Broadcast(map[string]any{
"type": "observation",
"action": "bulk_import",
"project": req.Project,
"count": imported,
})
writeJSON(w, BulkImportResponse{
Imported: imported,
Failed: failed,
SkippedDuplicates: skippedDupes,
Errors: errors,
})
}
// ArchiveRequest is the request body for archiving observations.
type ArchiveRequest struct {
Project string `json:"project,omitempty"`
Reason string `json:"reason,omitempty"`
IDs []int64 `json:"ids,omitempty"`
MaxAgeDays int `json:"max_age_days,omitempty"`
}
// handleArchiveObservations archives observations by ID or by age.
// Supports batch archival with error tracking per observation.
func (s *Service) handleArchiveObservations(w http.ResponseWriter, r *http.Request) {
var req ArchiveRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest)
return
}
var archivedIDs []int64
var failedIDs []int64
var errors []string
var err error
if len(req.IDs) > 0 {
// Archive specific observations with parallel processing for large batches
if len(req.IDs) > 5 {
// Use parallel archival for batches larger than 5
type archiveResult struct {
err error
id int64
}
results := make(chan archiveResult, len(req.IDs))
// Limit concurrency to avoid overwhelming the database
sem := make(chan struct{}, 5)
var wg sync.WaitGroup
for _, id := range req.IDs {
wg.Add(1)
go func(obsID int64) {
defer wg.Done()
sem <- struct{}{} // Acquire
defer func() { <-sem }() // Release
archErr := s.observationStore.ArchiveObservation(r.Context(), obsID, req.Reason)
results <- archiveResult{id: obsID, err: archErr}
}(id)
}
// Close results channel when all goroutines complete
go func() {
wg.Wait()
close(results)
}()
// Collect results
for res := range results {
if res.err != nil {
log.Warn().Err(res.err).Int64("id", res.id).Msg("Failed to archive observation")
failedIDs = append(failedIDs, res.id)
errors = append(errors, fmt.Sprintf("id %d: %v", res.id, res.err))
} else {
archivedIDs = append(archivedIDs, res.id)
}
}
} else {
// Sequential for small batches
for _, id := range req.IDs {
if archErr := s.observationStore.ArchiveObservation(r.Context(), id, req.Reason); archErr != nil {
log.Warn().Err(archErr).Int64("id", id).Msg("Failed to archive observation")
failedIDs = append(failedIDs, id)
errors = append(errors, fmt.Sprintf("id %d: %v", id, archErr))
} else {
archivedIDs = append(archivedIDs, id)
}
}
}
} else if req.Project != "" || req.MaxAgeDays > 0 {
// Archive by age
archivedIDs, err = s.observationStore.ArchiveOldObservations(r.Context(), req.Project, req.MaxAgeDays, req.Reason)
if err != nil {
http.Error(w, "failed to archive: "+err.Error(), http.StatusInternalServerError)
return
}
} else {
http.Error(w, "either 'ids' or 'project'/'max_age_days' is required", http.StatusBadRequest)
return
}
log.Info().
Str("project", req.Project).
Int("archived", len(archivedIDs)).
Int("failed", len(failedIDs)).
Msg("Observations archived")
// Invalidate cache if any observations were archived
if len(archivedIDs) > 0 {
if req.Project != "" {
s.invalidateObsCountCache(req.Project)
} else {
s.invalidateAllObsCountCache()
}
}
response := map[string]any{
"archived_count": len(archivedIDs),
"archived_ids": archivedIDs,
}
if len(failedIDs) > 0 {
response["failed_count"] = len(failedIDs)
response["failed_ids"] = failedIDs
response["errors"] = errors
}
writeJSON(w, response)
}
// handleUnarchiveObservation restores an archived observation.
func (s *Service) handleUnarchiveObservation(w http.ResponseWriter, r *http.Request) {
idStr := chi.URLParam(r, "id")
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
http.Error(w, "invalid observation id", http.StatusBadRequest)
return
}
if err := s.observationStore.UnarchiveObservation(r.Context(), id); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Invalidate all caches since we don't know the project
s.invalidateAllObsCountCache()
writeJSON(w, map[string]any{
"success": true,
"id": id,
})
}
// handleGetArchivedObservations returns archived observations.
func (s *Service) handleGetArchivedObservations(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
limit := gorm.ParseLimitParam(r, DefaultObservationsLimit)
observations, err := s.observationStore.GetArchivedObservations(r.Context(), project, limit)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if observations == nil {
observations = []*models.Observation{}
}
writeJSON(w, observations)
}
// handleGetArchivalStats returns archival statistics.
func (s *Service) handleGetArchivalStats(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
stats, err := s.observationStore.GetArchivalStats(r.Context(), project)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
writeJSON(w, stats)
}
// handleExportObservations exports observations in JSON or CSV format.
// Supports query parameters: project, format (json/csv), scope, type, limit.
func (s *Service) handleExportObservations(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
format := r.URL.Query().Get("format")
if format == "" {
format = "json"
}
scope := r.URL.Query().Get("scope") // project, global, or empty for all
obsType := r.URL.Query().Get("type") // bugfix, feature, etc.
limit := gorm.ParseLimitParamWithMax(r, 1000, 5000) // Higher limit for exports, capped at 5000
// Validate format
if format != "json" && format != "csv" {
http.Error(w, "format must be 'json' or 'csv'", http.StatusBadRequest)
return
}
// Get observations with filters
ctx := r.Context()
var observations []*models.Observation
var err error
if project != "" {
observations, _, err = s.observationStore.GetObservationsByProjectStrictPaginated(ctx, project, limit, 0)
} else {
observations, _, err = s.observationStore.GetAllRecentObservationsPaginated(ctx, limit, 0)
}
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Apply additional filters
if scope != "" || obsType != "" {
filtered := make([]*models.Observation, 0, len(observations))
for _, obs := range observations {
if scope != "" && string(obs.Scope) != scope {
continue
}
if obsType != "" && string(obs.Type) != obsType {
continue
}
filtered = append(filtered, obs)
}
observations = filtered
}
// Generate filename
timestamp := time.Now().Format("20060102-150405")
filename := fmt.Sprintf("observations-%s.%s", timestamp, format)
if project != "" {
// Sanitize project name for filename
sanitized := strings.ReplaceAll(project, "/", "_")
sanitized = strings.ReplaceAll(sanitized, "\\", "_")
if len(sanitized) > 50 {
sanitized = sanitized[:50]
}
filename = fmt.Sprintf("observations-%s-%s.%s", sanitized, timestamp, format)
}
switch format {
case "csv":
w.Header().Set("Content-Type", "text/csv")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename))
s.writeObservationsCSV(w, observations)
default: // json
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename))
writeJSON(w, map[string]any{
"exported_at": time.Now().Format(time.RFC3339),
"project": project,
"count": len(observations),
"observations": observations,
})
}
}
// writeObservationsCSV writes observations in CSV format.
// Uses fmt.Fprintf directly to avoid intermediate string allocations.
func (s *Service) writeObservationsCSV(w http.ResponseWriter, observations []*models.Observation) {
// Write CSV header
_, _ = io.WriteString(w, "id,type,scope,project,title,subtitle,narrative,concepts,facts,created_at,importance_score\n")
for _, obs := range observations {
// Write directly to avoid string allocation per row
_, _ = fmt.Fprintf(w, "%d,%s,%s,%s,%s,%s,%s,%s,%s,%s,%.2f\n",
obs.ID,
obs.Type,
obs.Scope,
escapeCsvField(obs.Project),
escapeCsvField(obs.Title.String),
escapeCsvField(obs.Subtitle.String),
escapeCsvField(obs.Narrative.String),
escapeCsvField(strings.Join(obs.Concepts, ";")),
escapeCsvField(strings.Join(obs.Facts, ";")),
obs.CreatedAt,
obs.ImportanceScore,
)
}
}
// escapeCsvField escapes a field for CSV output.
func escapeCsvField(s string) string {
// If field contains comma, quote, or newline, wrap in quotes and escape quotes
if strings.ContainsAny(s, ",\"\n\r") {
s = strings.ReplaceAll(s, "\"", "\"\"")
return "\"" + s + "\""
}
return s
}
// BulkStatusRequest represents a request to update status for multiple observations.
type BulkStatusRequest struct {
Action string `json:"action"`
Reason string `json:"reason,omitempty"`
IDs []int64 `json:"ids"`
Feedback int `json:"feedback,omitempty"`
}
// handleBulkStatusUpdate updates status for multiple observations in one request.
func (s *Service) handleBulkStatusUpdate(w http.ResponseWriter, r *http.Request) {
var req BulkStatusRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest)
return
}
if len(req.IDs) == 0 {
http.Error(w, "ids is required", http.StatusBadRequest)
return
}
if len(req.IDs) > 500 {
http.Error(w, "maximum 500 ids per request", http.StatusBadRequest)
return
}
ctx := r.Context()
var updated, failed int
var errors []string
switch req.Action {
case "supersede":
for _, id := range req.IDs {
if err := s.observationStore.MarkAsSuperseded(ctx, id); err != nil {
failed++
errors = append(errors, fmt.Sprintf("id %d: %v", id, err))
} else {
updated++
}
}
case "archive":
for _, id := range req.IDs {
if err := s.observationStore.ArchiveObservation(ctx, id, req.Reason); err != nil {
failed++
errors = append(errors, fmt.Sprintf("id %d: %v", id, err))
} else {
updated++
}
}
case "set_feedback":
if req.Feedback < -1 || req.Feedback > 1 {
http.Error(w, "feedback must be -1, 0, or 1", http.StatusBadRequest)
return
}
for _, id := range req.IDs {
if err := s.observationStore.UpdateObservationFeedback(ctx, id, req.Feedback); err != nil {
failed++
errors = append(errors, fmt.Sprintf("id %d: %v", id, err))
} else {
updated++
}
}
default:
http.Error(w, "action must be 'supersede', 'archive', or 'set_feedback'", http.StatusBadRequest)
return
}
// Invalidate cache for archive action (affects observation counts)
if req.Action == "archive" && updated > 0 {
// No project info available, invalidate all caches
s.invalidateAllObsCountCache()
}
response := map[string]any{
"action": req.Action,
"updated": updated,
"failed": failed,
}
if len(errors) > 0 {
response["errors"] = errors
}
writeJSON(w, response)
}
// handleFindDuplicates finds potential duplicate observations using similarity clustering.
// Returns groups of similar observations that may be candidates for merging or archival.
func (s *Service) handleFindDuplicates(w http.ResponseWriter, r *http.Request) {
project := r.URL.Query().Get("project")
thresholdStr := r.URL.Query().Get("threshold")
limit := gorm.ParseLimitParam(r, 100)
// Parse threshold (default 0.6 = 60% similarity)
threshold := 0.6
if thresholdStr != "" {
if t, err := strconv.ParseFloat(thresholdStr, 64); err == nil && t > 0 && t < 1 {
threshold = t
}
}
// Get recent observations
ctx := r.Context()
var observations []*models.Observation
var err error
if project != "" {
observations, _, err = s.observationStore.GetObservationsByProjectStrictPaginated(ctx, project, limit, 0)
} else {
observations, _, err = s.observationStore.GetAllRecentObservationsPaginated(ctx, limit, 0)
}
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if len(observations) < 2 {
writeJSON(w, map[string]any{
"duplicate_groups": []any{},
"total_checked": len(observations),
"threshold": threshold,
})
return
}
// Find duplicates using similarity comparison
type duplicateGroup struct {
Observations []map[string]any `json:"observations"`
Similarity float64 `json:"similarity"`
}
groups := []duplicateGroup{}
processed := make(map[int64]bool)
for i, obs1 := range observations {
if processed[obs1.ID] {
continue
}
terms1 := similarity.ExtractObservationTerms(obs1)
if len(terms1) == 0 {
continue
}
group := duplicateGroup{
Observations: []map[string]any{obs1.ToMap()},
Similarity: 1.0,
}
for j := i + 1; j < len(observations); j++ {
obs2 := observations[j]
if processed[obs2.ID] {
continue
}
terms2 := similarity.ExtractObservationTerms(obs2)
sim := similarity.JaccardSimilarity(terms1, terms2)
if sim >= threshold {
obsMap := obs2.ToMap()
obsMap["similarity_to_first"] = sim
group.Observations = append(group.Observations, obsMap)
group.Similarity = min(group.Similarity, sim)
processed[obs2.ID] = true
}
}
if len(group.Observations) > 1 {
processed[obs1.ID] = true
groups = append(groups, group)
}
}
// Sort groups by size (largest first)
sort.Slice(groups, func(i, j int) bool {
return len(groups[i].Observations) > len(groups[j].Observations)
})
writeJSON(w, map[string]any{
"duplicate_groups": groups,
"total_checked": len(observations),
"groups_found": len(groups),
"threshold": threshold,
"project": project,
})
}
+9 -5
View File
@@ -10,6 +10,7 @@ import (
"github.com/go-chi/chi/v5"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/rs/zerolog/log"
)
// FeedbackRequest represents a user feedback submission.
@@ -311,8 +312,7 @@ func (s *Service) handleTriggerRecalculation(w http.ResponseWriter, r *http.Requ
// Run recalculation in background
go func() {
if err := recalculator.RecalculateNow(r.Context()); err != nil {
// Log error but don't block response
_ = err // Explicitly ignore - background operation
log.Warn().Err(err).Msg("Background score recalculation failed")
}
}()
@@ -345,14 +345,18 @@ func (s *Service) incrementRetrievalCounts(ids []int64) {
}
// Increment in background to not block response
// Use service context to respect shutdown signals
s.wg.Add(1)
go func() {
// Create a new context with timeout for the background operation
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer s.wg.Done()
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
defer cancel()
if err := store.IncrementRetrievalCount(ctx, ids); err != nil {
// Log but don't fail - this is a background operation
_ = err // Explicitly ignore - background operation
if s.ctx.Err() == nil { // Don't log during shutdown
log.Debug().Err(err).Msg("Failed to increment retrieval counts")
}
}
}()
}
+354
View File
@@ -0,0 +1,354 @@
// Package worker provides session-related HTTP handlers.
package worker
import (
"context"
"encoding/json"
"net/http"
"strconv"
"time"
"github.com/go-chi/chi/v5"
"github.com/lukaszraczylo/claude-mnemonic/internal/privacy"
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/session"
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
"github.com/rs/zerolog/log"
)
// SessionInitRequest is the request body for session initialization.
type SessionInitRequest struct {
ClaudeSessionID string `json:"claudeSessionId"`
Project string `json:"project"`
Prompt string `json:"prompt"`
MatchedObservations int `json:"matchedObservations"`
}
// SessionInitResponse is the response for session initialization.
type SessionInitResponse struct {
Reason string `json:"reason,omitempty"`
SessionDBID int64 `json:"sessionDbId"`
PromptNumber int `json:"promptNumber"`
Skipped bool `json:"skipped,omitempty"`
}
// DuplicatePromptWindowSeconds is the time window for detecting duplicate prompt submissions.
// If the same prompt text is seen within this window, it's considered a duplicate hook invocation.
const DuplicatePromptWindowSeconds = 10
// handleSessionInit handles session initialization from user-prompt hook.
// This handler is idempotent - duplicate requests within a short time window
// return the existing prompt data without creating duplicates.
func (s *Service) handleSessionInit(w http.ResponseWriter, r *http.Request) {
var req SessionInitRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Privacy check
if privacy.IsEntirelyPrivate(req.Prompt) {
// Create session but skip processing
sessionID, _ := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, "")
promptNum, _ := s.sessionStore.IncrementPromptCounter(r.Context(), sessionID)
writeJSON(w, SessionInitResponse{
SessionDBID: sessionID,
PromptNumber: promptNum,
Skipped: true,
Reason: "private",
})
return
}
// Clean prompt
cleanedPrompt := privacy.Clean(req.Prompt)
// DUPLICATE DETECTION: Check if this exact prompt was already saved recently.
// This prevents the bug where the hook fires multiple times for the same user action,
// creating many duplicate prompts with incrementing numbers.
if existingID, existingNum, found := s.promptStore.FindRecentPromptByText(r.Context(), req.ClaudeSessionID, cleanedPrompt, DuplicatePromptWindowSeconds); found {
// Get or create session (idempotent)
sessionID, _ := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, cleanedPrompt)
log.Debug().
Int64("sessionId", sessionID).
Int("promptNumber", existingNum).
Int64("promptId", existingID).
Msg("Duplicate prompt detected - returning existing")
// Return existing prompt data without incrementing or saving again
writeJSON(w, SessionInitResponse{
SessionDBID: sessionID,
PromptNumber: existingNum,
})
return
}
// Create session (idempotent)
sessionID, err := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, cleanedPrompt)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Increment prompt counter
promptNum, err := s.sessionStore.IncrementPromptCounter(r.Context(), sessionID)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Save user prompt with matched observation count
promptID, err := s.promptStore.SaveUserPromptWithMatches(r.Context(), req.ClaudeSessionID, promptNum, cleanedPrompt, req.MatchedObservations)
if err != nil {
log.Warn().Err(err).Msg("Failed to save user prompt")
// Non-fatal: continue with session initialization
} else if s.vectorSync != nil {
// Sync to vector DB asynchronously (non-blocking)
now := time.Now()
promptWithSession := &models.UserPromptWithSession{
UserPrompt: models.UserPrompt{
ID: promptID,
ClaudeSessionID: req.ClaudeSessionID,
PromptNumber: promptNum,
PromptText: cleanedPrompt,
MatchedObservations: req.MatchedObservations,
CreatedAt: now.Format(time.RFC3339),
CreatedAtEpoch: now.UnixMilli(),
},
Project: req.Project,
SDKSessionID: req.ClaudeSessionID,
}
s.asyncVectorSync(func() {
// Use service context as parent to respect shutdown signals
ctx, cancel := context.WithTimeout(s.ctx, 10*time.Second)
defer cancel()
if err := s.vectorSync.SyncUserPrompt(ctx, promptWithSession); err != nil {
if s.ctx.Err() == nil { // Don't log during shutdown
log.Warn().Err(err).Int64("id", promptID).Msg("Failed to sync user prompt to sqlite-vec")
}
}
})
}
log.Info().
Int64("sessionId", sessionID).
Int("promptNumber", promptNum).
Str("project", req.Project).
Msg("Session initialized")
// Broadcast prompt event for dashboard refresh
s.sseBroadcaster.Broadcast(map[string]any{
"type": "prompt",
"action": "created",
"project": req.Project,
})
writeJSON(w, SessionInitResponse{
SessionDBID: sessionID,
PromptNumber: promptNum,
})
}
// SessionStartRequest is the request body for starting SDK agent.
type SessionStartRequest struct {
UserPrompt string `json:"userPrompt"`
PromptNumber int `json:"promptNumber"`
}
// handleSessionStart handles SDK agent session start.
func (s *Service) handleSessionStart(w http.ResponseWriter, r *http.Request) {
idStr := chi.URLParam(r, "id")
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
http.Error(w, "invalid session id", http.StatusBadRequest)
return
}
var req SessionStartRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Initialize session in manager
sess, err := s.sessionManager.InitializeSession(r.Context(), id, req.UserPrompt, req.PromptNumber)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if sess == nil {
http.Error(w, "session not found", http.StatusNotFound)
return
}
// Session is now registered. Observations will be processed
// asynchronously by the background queue processor (processQueue in service.go).
log.Info().
Int64("sessionId", id).
Int("promptNumber", req.PromptNumber).
Msg("SDK agent session initialized")
s.broadcastProcessingStatus()
w.WriteHeader(http.StatusOK)
}
// ObservationRequest is the request body for posting observations.
type ObservationRequest struct {
ClaudeSessionID string `json:"claudeSessionId"`
Project string `json:"project"`
ToolName string `json:"tool_name"`
ToolInput any `json:"tool_input"`
ToolResponse any `json:"tool_response"`
CWD string `json:"cwd"`
}
// handleObservation handles observation posting from post-tool-use hook.
func (s *Service) handleObservation(w http.ResponseWriter, r *http.Request) {
var req ObservationRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Find session
sess, err := s.sessionStore.FindAnySDKSession(r.Context(), req.ClaudeSessionID)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if sess == nil {
// Create session on-the-fly with project from request
id, err := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, "")
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
sess, _ = s.sessionStore.GetSessionByID(r.Context(), id)
}
// Queue observation
if err := s.sessionManager.QueueObservation(r.Context(), sess.ID, session.ObservationData{
ToolName: req.ToolName,
ToolInput: req.ToolInput,
ToolResponse: req.ToolResponse,
CWD: req.CWD,
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
s.broadcastProcessingStatus()
w.WriteHeader(http.StatusOK)
}
// SubagentCompleteRequest is the request body for subagent completion.
type SubagentCompleteRequest struct {
ClaudeSessionID string `json:"claudeSessionId"`
Project string `json:"project"`
}
// handleSubagentComplete handles subagent/Task completion notifications.
// This triggers immediate processing of any queued observations from the subagent.
func (s *Service) handleSubagentComplete(w http.ResponseWriter, r *http.Request) {
var req SubagentCompleteRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Find session
sess, err := s.sessionStore.FindAnySDKSession(r.Context(), req.ClaudeSessionID)
if err != nil || sess == nil {
// Session not found - subagent may have been in a different context
log.Debug().
Str("claudeSessionId", req.ClaudeSessionID).
Msg("Subagent complete - no active session found")
w.WriteHeader(http.StatusOK)
return
}
// Trigger immediate processing of queued observations
messages := s.sessionManager.DrainMessages(sess.ID)
if len(messages) > 0 && s.processor != nil {
log.Info().
Int64("sessionId", sess.ID).
Int("messages", len(messages)).
Msg("Processing queued observations from subagent")
for _, msg := range messages {
if msg.Type == session.MessageTypeObservation && msg.Observation != nil {
err := s.processor.ProcessObservation(
r.Context(),
sess.SDKSessionID.String,
sess.Project,
msg.Observation.ToolName,
msg.Observation.ToolInput,
msg.Observation.ToolResponse,
msg.Observation.PromptNumber,
msg.Observation.CWD,
)
if err != nil {
log.Error().Err(err).
Str("tool", msg.Observation.ToolName).
Msg("Failed to process subagent observation")
}
}
}
}
s.broadcastProcessingStatus()
w.WriteHeader(http.StatusOK)
}
// handleGetSessionByClaudeID looks up a session by Claude session ID.
func (s *Service) handleGetSessionByClaudeID(w http.ResponseWriter, r *http.Request) {
claudeSessionID := r.URL.Query().Get("claudeSessionId")
if claudeSessionID == "" {
http.Error(w, "claudeSessionId required", http.StatusBadRequest)
return
}
session, err := s.sessionStore.FindAnySDKSession(r.Context(), claudeSessionID)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if session == nil {
http.Error(w, "session not found", http.StatusNotFound)
return
}
writeJSON(w, session)
}
// SummarizeRequest is the request body for summarize requests.
type SummarizeRequest struct {
LastUserMessage string `json:"lastUserMessage"`
LastAssistantMessage string `json:"lastAssistantMessage"`
}
// handleSummarize handles summarize requests from stop hook.
func (s *Service) handleSummarize(w http.ResponseWriter, r *http.Request) {
idStr := chi.URLParam(r, "id")
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
http.Error(w, "invalid session id", http.StatusBadRequest)
return
}
var req SummarizeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Queue summarize request
if err := s.sessionManager.QueueSummarize(r.Context(), id, req.LastUserMessage, req.LastAssistantMessage); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
s.broadcastProcessingStatus()
w.WriteHeader(http.StatusOK)
}
+25 -12
View File
@@ -66,6 +66,7 @@ func testService(t *testing.T) (*Service, func()) {
cancel: cancel,
startTime: time.Now(),
retrievalStats: make(map[string]*RetrievalStats),
cachedObsCounts: make(map[string]cachedCount),
}
svc.setupRoutes()
@@ -345,11 +346,13 @@ func TestHandleGetObservations_Limit(t *testing.T) {
assert.Equal(t, http.StatusOK, rec.Code)
// Parse as generic JSON array since the model uses custom marshaling
var observations []map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &observations)
// Parse as object with observations key (API returns wrapped response)
var response map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
observations, ok := response["observations"].([]interface{})
require.True(t, ok, "expected observations array in response")
assert.Len(t, observations, 10)
}
@@ -1135,10 +1138,13 @@ func TestHandleGetObservations_DefaultLimit(t *testing.T) {
assert.Equal(t, http.StatusOK, rec.Code)
var observations []map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &observations)
var response map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
observations, ok := response["observations"].([]interface{})
require.True(t, ok, "expected observations array in response")
// Should return default limit (100)
assert.LessOrEqual(t, len(observations), DefaultObservationsLimit)
}
@@ -1159,10 +1165,12 @@ func TestHandleGetObservations_FilterByProject(t *testing.T) {
assert.Equal(t, http.StatusOK, rec.Code)
var observations []map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &observations)
var response map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
observations, ok := response["observations"].([]interface{})
require.True(t, ok, "expected observations array in response")
assert.Len(t, observations, 2)
}
@@ -1412,10 +1420,12 @@ func TestHandleGetObservations(t *testing.T) {
assert.Equal(t, http.StatusOK, rec.Code)
var observations []map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &observations)
var response map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
observations, ok := response["observations"].([]interface{})
require.True(t, ok, "expected observations array in response")
assert.GreaterOrEqual(t, len(observations), 2)
}
@@ -2697,10 +2707,13 @@ func TestHandleGetObservations_EmptyResult(t *testing.T) {
assert.Equal(t, http.StatusOK, rec.Code)
// Should return empty array, not null
var obs []interface{}
err := json.Unmarshal(rec.Body.Bytes(), &obs)
// Should return empty array within observations key, not null
var response map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)
obs, ok := response["observations"].([]interface{})
require.True(t, ok, "expected observations array in response")
assert.NotNil(t, obs)
}
+243
View File
@@ -0,0 +1,243 @@
// Package worker provides update and restart HTTP handlers.
package worker
import (
"fmt"
"net/http"
"time"
"github.com/rs/zerolog/log"
)
// handleUpdateCheck checks for available updates.
func (s *Service) handleUpdateCheck(w http.ResponseWriter, r *http.Request) {
info, err := s.updater.CheckForUpdate(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
writeJSON(w, info)
}
// handleUpdateApply downloads and applies an available update.
func (s *Service) handleUpdateApply(w http.ResponseWriter, r *http.Request) {
// First check for update
info, err := s.updater.CheckForUpdate(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if !info.Available {
writeJSON(w, map[string]any{
"success": false,
"message": "No update available",
})
return
}
// Apply update in background with tracking for graceful shutdown
s.wg.Go(func() {
if err := s.updater.ApplyUpdate(s.ctx, info); err != nil {
log.Error().Err(err).Msg("Update failed")
}
})
writeJSON(w, map[string]any{
"success": true,
"message": "Update started",
"version": info.LatestVersion,
})
}
// handleUpdateStatus returns the current update status.
func (s *Service) handleUpdateStatus(w http.ResponseWriter, r *http.Request) {
status := s.updater.GetStatus()
writeJSON(w, status)
}
// ComponentHealth represents the health status of a single component.
type ComponentHealth struct {
Name string `json:"name"`
Status string `json:"status"` // "healthy", "degraded", "unhealthy"
Message string `json:"message,omitempty"`
}
// SelfCheckResponse contains the health status of all components.
type SelfCheckResponse struct {
Overall string `json:"overall"` // "healthy", "degraded", "unhealthy"
Version string `json:"version"`
Uptime string `json:"uptime"`
Components []ComponentHealth `json:"components"`
}
// handleSelfCheck returns the health status of all components.
func (s *Service) handleSelfCheck(w http.ResponseWriter, r *http.Request) {
components := []ComponentHealth{}
overall := "healthy"
// Check Worker Service
workerStatus := ComponentHealth{Name: "Worker Service", Status: "healthy"}
if !s.ready.Load() {
if err := s.GetInitError(); err != nil {
workerStatus.Status = "unhealthy"
workerStatus.Message = err.Error()
overall = "unhealthy"
} else {
workerStatus.Status = "degraded"
workerStatus.Message = "Initializing"
if overall == "healthy" {
overall = "degraded"
}
}
}
components = append(components, workerStatus)
// Check SQLite Database
dbStatus := ComponentHealth{Name: "SQLite Database", Status: "healthy"}
if s.store == nil {
dbStatus.Status = "unhealthy"
dbStatus.Message = "Not initialized"
overall = "unhealthy"
} else if err := s.store.Ping(); err != nil {
dbStatus.Status = "unhealthy"
dbStatus.Message = err.Error()
overall = "unhealthy"
}
components = append(components, dbStatus)
// Check Vector DB (sqlite-vec)
vectorStatus := ComponentHealth{Name: "Vector DB", Status: "healthy"}
if s.vectorClient == nil {
vectorStatus.Status = "degraded"
vectorStatus.Message = "Not configured"
if overall == "healthy" {
overall = "degraded"
}
} else if !s.vectorClient.IsConnected() {
vectorStatus.Status = "degraded"
vectorStatus.Message = "Not connected"
if overall == "healthy" {
overall = "degraded"
}
}
components = append(components, vectorStatus)
// Check SDK Processor
sdkStatus := ComponentHealth{Name: "SDK Processor", Status: "healthy"}
if s.processor == nil {
sdkStatus.Status = "degraded"
sdkStatus.Message = "Not initialized"
if overall == "healthy" {
overall = "degraded"
}
} else if !s.processor.IsAvailable() {
sdkStatus.Status = "degraded"
sdkStatus.Message = "Claude CLI not available"
if overall == "healthy" {
overall = "degraded"
}
}
components = append(components, sdkStatus)
// Check SSE Broadcaster
sseStatus := ComponentHealth{Name: "SSE Broadcaster", Status: "healthy"}
if s.sseBroadcaster == nil {
sseStatus.Status = "unhealthy"
sseStatus.Message = "Not initialized"
overall = "unhealthy"
}
components = append(components, sseStatus)
// Check Cross-Encoder Reranker
rerankerStatus := ComponentHealth{Name: "Cross-Encoder Reranker", Status: "healthy"}
if !s.config.RerankingEnabled {
rerankerStatus.Status = "degraded"
rerankerStatus.Message = "Disabled in config"
if overall == "healthy" {
overall = "degraded"
}
} else if s.reranker == nil {
rerankerStatus.Status = "degraded"
rerankerStatus.Message = "Not initialized"
if overall == "healthy" {
overall = "degraded"
}
} else {
// Verify reranker is functional using Score
_, normalizedScore, err := s.reranker.Score("test query", "test document")
if err != nil {
rerankerStatus.Status = "unhealthy"
rerankerStatus.Message = fmt.Sprintf("Score check failed: %v", err)
if overall == "healthy" {
overall = "degraded"
}
} else {
rerankerStatus.Message = fmt.Sprintf("Score check passed (%.4f)", normalizedScore)
}
}
components = append(components, rerankerStatus)
// Calculate uptime
uptime := time.Since(s.startTime).Round(time.Second).String()
writeJSON(w, SelfCheckResponse{
Overall: overall,
Version: s.version,
Uptime: uptime,
Components: components,
})
}
// handleUpdateRestart restarts the worker with the new binary (after update).
func (s *Service) handleUpdateRestart(w http.ResponseWriter, r *http.Request) {
status := s.updater.GetStatus()
if status.State != "done" {
http.Error(w, "no update has been applied", http.StatusBadRequest)
return
}
// Send response before restarting
writeJSON(w, map[string]any{
"success": true,
"message": "Restarting worker...",
})
// Flush the response
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
// Restart in background after response is sent
go func() {
if err := s.updater.Restart(); err != nil {
log.Error().Err(err).Msg("Failed to restart worker")
}
}()
}
// handleRestart restarts the worker process (general restart, not tied to update).
func (s *Service) handleRestart(w http.ResponseWriter, r *http.Request) {
log.Info().Msg("Manual restart requested via API")
// Send response before restarting
writeJSON(w, map[string]any{
"success": true,
"message": "Restarting worker...",
"version": s.version,
})
// Flush the response
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
// Restart in background after response is sent
go func() {
// Small delay to ensure response is sent
time.Sleep(100 * time.Millisecond)
if err := s.updater.Restart(); err != nil {
log.Error().Err(err).Msg("Failed to restart worker")
}
}()
}
+333
View File
@@ -0,0 +1,333 @@
// Package worker provides the main worker service for claude-mnemonic.
package worker
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"net/http"
"regexp"
"strings"
"sync"
"time"
)
// requestIDKey is the context key for request IDs.
type requestIDKey struct{}
// projectNamePattern validates project names to prevent path traversal.
var projectNamePattern = regexp.MustCompile(`^[a-zA-Z0-9_./-]+$`)
// allowedOrigins is the whitelist of origins allowed for CORS.
// Uses exact matching to prevent bypass attacks like "evil-localhost.com".
var allowedOrigins = map[string]bool{
"http://localhost": true,
"http://localhost:3000": true,
"http://localhost:5173": true, // Vite dev server
"http://localhost:37778": true, // Dashboard UI
"http://127.0.0.1": true,
"http://127.0.0.1:3000": true,
"http://127.0.0.1:5173": true,
"http://127.0.0.1:37778": true,
}
// SecurityHeaders middleware adds essential security headers to all responses.
// These protect against common web vulnerabilities.
func SecurityHeaders(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Prevent clickjacking
w.Header().Set("X-Frame-Options", "DENY")
// Prevent MIME type sniffing
w.Header().Set("X-Content-Type-Options", "nosniff")
// Enable XSS filter
w.Header().Set("X-XSS-Protection", "1; mode=block")
// Restrict referrer information
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
// Content Security Policy - restrict to self
w.Header().Set("Content-Security-Policy", "default-src 'self'")
// Permissions Policy - disable unnecessary features
w.Header().Set("Permissions-Policy", "geolocation=(), microphone=(), camera=()")
// CORS: Use exact match whitelist to prevent bypass attacks
origin := r.Header.Get("Origin")
if allowedOrigins[origin] {
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, X-Auth-Token, Authorization, X-Request-ID")
}
// Handle preflight requests
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusNoContent)
return
}
next.ServeHTTP(w, r)
})
}
// MaxBodySize middleware limits the size of incoming request bodies.
// This prevents denial of service attacks via large payloads.
func MaxBodySize(maxBytes int64) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.ContentLength > maxBytes {
http.Error(w, "request body too large", http.StatusRequestEntityTooLarge)
return
}
r.Body = http.MaxBytesReader(w, r.Body, maxBytes)
next.ServeHTTP(w, r)
})
}
}
// TokenAuth provides simple token-based authentication for localhost services.
// The token is generated at startup and must be provided in the X-Auth-Token header.
type TokenAuth struct {
ExemptPaths map[string]bool
token string
mu sync.RWMutex
enabled bool
}
// NewTokenAuth creates a new TokenAuth with a randomly generated token.
// If enabled is false, authentication is skipped (useful for development).
func NewTokenAuth(enabled bool) (*TokenAuth, error) {
ta := &TokenAuth{
enabled: enabled,
ExemptPaths: map[string]bool{
"/health": true,
"/api/health": true,
"/api/ready": true,
},
}
if enabled {
// Generate 32-byte random token
tokenBytes := make([]byte, 32)
if _, err := rand.Read(tokenBytes); err != nil {
return nil, err
}
ta.token = hex.EncodeToString(tokenBytes)
}
return ta, nil
}
// Token returns the authentication token.
// Returns empty string if authentication is disabled.
func (ta *TokenAuth) Token() string {
ta.mu.RLock()
defer ta.mu.RUnlock()
return ta.token
}
// IsEnabled returns whether token authentication is enabled.
func (ta *TokenAuth) IsEnabled() bool {
ta.mu.RLock()
defer ta.mu.RUnlock()
return ta.enabled
}
// Middleware returns HTTP middleware that enforces token authentication.
func (ta *TokenAuth) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ta.mu.RLock()
enabled := ta.enabled
token := ta.token
exempt := ta.ExemptPaths[r.URL.Path]
ta.mu.RUnlock()
// Skip auth if disabled or path is exempt
if !enabled || exempt {
next.ServeHTTP(w, r)
return
}
// Check for token in header
providedToken := r.Header.Get("X-Auth-Token")
if providedToken == "" {
// Also check Authorization header with Bearer scheme
auth := r.Header.Get("Authorization")
if bearer, found := strings.CutPrefix(auth, "Bearer "); found {
providedToken = bearer
}
}
if providedToken != token {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}
// ExpensiveOperationLimiter provides stricter rate limiting for expensive operations.
// It wraps the base per-client rate limiter with additional per-operation limits.
type ExpensiveOperationLimiter struct {
// Track last execution time per operation type
lastRebuild int64 // Unix timestamp
rebuildCooldown int64 // Minimum seconds between rebuilds
mu sync.Mutex
}
// NewExpensiveOperationLimiter creates a limiter for expensive operations.
func NewExpensiveOperationLimiter() *ExpensiveOperationLimiter {
return &ExpensiveOperationLimiter{
rebuildCooldown: 300, // 5 minutes between rebuilds
}
}
// CanRebuild checks if a vector rebuild operation is allowed.
// Returns false if a rebuild was triggered too recently.
func (eol *ExpensiveOperationLimiter) CanRebuild() bool {
eol.mu.Lock()
defer eol.mu.Unlock()
now := unixNow()
if now-eol.lastRebuild < eol.rebuildCooldown {
return false
}
eol.lastRebuild = now
return true
}
// unixNow returns current Unix timestamp.
// Separated for easier testing.
func unixNow() int64 {
return time.Now().Unix()
}
// RequestID middleware adds a unique request ID to each request.
// The ID is added to the context and response headers for tracing.
func RequestID(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check for existing request ID from client
requestID := r.Header.Get("X-Request-ID")
if requestID == "" {
// Generate new request ID
idBytes := make([]byte, 8)
if _, err := rand.Read(idBytes); err == nil {
requestID = hex.EncodeToString(idBytes)
} else {
requestID = fmt.Sprintf("%d", time.Now().UnixNano())
}
}
// Add to response header
w.Header().Set("X-Request-ID", requestID)
// Add to context
ctx := context.WithValue(r.Context(), requestIDKey{}, requestID)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// GetRequestID retrieves the request ID from the context.
func GetRequestID(ctx context.Context) string {
if id, ok := ctx.Value(requestIDKey{}).(string); ok {
return id
}
return ""
}
// RequireJSONContentType middleware validates that POST/PUT/PATCH requests
// have application/json Content-Type header.
func RequireJSONContentType(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Only check for methods that typically have bodies
if r.Method == "POST" || r.Method == "PUT" || r.Method == "PATCH" {
ct := r.Header.Get("Content-Type")
// Allow empty Content-Type for requests without body
if ct != "" && !strings.HasPrefix(ct, "application/json") {
http.Error(w, "Content-Type must be application/json", http.StatusUnsupportedMediaType)
return
}
}
next.ServeHTTP(w, r)
})
}
// ValidateProjectName checks if a project name is safe to use.
// Returns an error if the name contains path traversal or invalid characters.
func ValidateProjectName(project string) error {
if project == "" {
return nil // Empty is allowed (means no filter)
}
// Check for path traversal
if strings.Contains(project, "..") {
return fmt.Errorf("invalid project name: path traversal detected")
}
// Check for valid characters
if !projectNamePattern.MatchString(project) {
return fmt.Errorf("invalid project name: only alphanumeric, underscore, dash, dot, and slash allowed")
}
// Max length check
if len(project) > 500 {
return fmt.Errorf("project name too long (max 500 chars)")
}
return nil
}
// BulkOperationLimiter provides rate limiting for bulk operations.
// Prevents DoS via repeated bulk requests.
type BulkOperationLimiter struct {
lastBulkOp int64 // Unix timestamp
cooldown int64 // Minimum seconds between operations
mu sync.Mutex
}
// NewBulkOperationLimiter creates a limiter for bulk operations.
func NewBulkOperationLimiter(cooldownSeconds int64) *BulkOperationLimiter {
return &BulkOperationLimiter{
cooldown: cooldownSeconds,
}
}
// CanExecute checks if a bulk operation is allowed.
// Returns false if a bulk operation was triggered too recently.
func (bol *BulkOperationLimiter) CanExecute() bool {
bol.mu.Lock()
defer bol.mu.Unlock()
now := unixNow()
if now-bol.lastBulkOp < bol.cooldown {
return false
}
bol.lastBulkOp = now
return true
}
// TimeSinceLastOp returns seconds since the last bulk operation.
func (bol *BulkOperationLimiter) TimeSinceLastOp() int64 {
bol.mu.Lock()
defer bol.mu.Unlock()
return unixNow() - bol.lastBulkOp
}
// CooldownRemaining returns seconds remaining in the cooldown period.
// Returns 0 if no cooldown is active.
func (bol *BulkOperationLimiter) CooldownRemaining() int64 {
bol.mu.Lock()
defer bol.mu.Unlock()
remaining := bol.cooldown - (unixNow() - bol.lastBulkOp)
if remaining < 0 {
return 0
}
return remaining
}
+515
View File
@@ -0,0 +1,515 @@
package worker
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestSecurityHeaders(t *testing.T) {
handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
// Check all security headers are set
tests := []struct {
header string
expected string
}{
{"X-Frame-Options", "DENY"},
{"X-Content-Type-Options", "nosniff"},
{"X-XSS-Protection", "1; mode=block"},
{"Referrer-Policy", "strict-origin-when-cross-origin"},
}
for _, tt := range tests {
if got := rr.Header().Get(tt.header); got != tt.expected {
t.Errorf("SecurityHeaders() %s = %q, want %q", tt.header, got, tt.expected)
}
}
}
func TestSecurityHeaders_CORS(t *testing.T) {
handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
tests := []struct {
name string
origin string
expectedOrigin string
expectCORS bool
}{
{
name: "localhost:37778 origin allowed",
origin: "http://localhost:37778",
expectCORS: true,
expectedOrigin: "http://localhost:37778",
},
{
name: "127.0.0.1:5173 origin allowed",
origin: "http://127.0.0.1:5173",
expectCORS: true,
expectedOrigin: "http://127.0.0.1:5173",
},
{
name: "localhost without port allowed",
origin: "http://localhost",
expectCORS: true,
expectedOrigin: "http://localhost",
},
{
name: "external origin blocked",
origin: "http://evil.com",
expectCORS: false,
},
{
name: "evil-localhost.com bypass attempt blocked",
origin: "http://evil-localhost.com",
expectCORS: false,
},
{
name: "localhost subdomain bypass attempt blocked",
origin: "http://localhost.evil.com",
expectCORS: false,
},
{
name: "unknown localhost port blocked",
origin: "http://localhost:9999",
expectCORS: false,
},
{
name: "no origin header",
origin: "",
expectCORS: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
if tt.origin != "" {
req.Header.Set("Origin", tt.origin)
}
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
cors := rr.Header().Get("Access-Control-Allow-Origin")
if tt.expectCORS {
if cors != tt.expectedOrigin {
t.Errorf("Expected CORS origin %q, got %q", tt.expectedOrigin, cors)
}
} else {
if cors != "" {
t.Errorf("Expected no CORS header, got %q", cors)
}
}
})
}
}
func TestMaxBodySize(t *testing.T) {
maxSize := int64(100)
handler := MaxBodySize(maxSize)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
tests := []struct {
name string
contentLength int64
expectedStatus int
}{
{
name: "within limit",
contentLength: 50,
expectedStatus: http.StatusOK,
},
{
name: "at limit",
contentLength: 100,
expectedStatus: http.StatusOK,
},
{
name: "exceeds limit",
contentLength: 150,
expectedStatus: http.StatusRequestEntityTooLarge,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("POST", "/test", nil)
req.ContentLength = tt.contentLength
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != tt.expectedStatus {
t.Errorf("MaxBodySize() status = %d, want %d", rr.Code, tt.expectedStatus)
}
})
}
}
func TestTokenAuth(t *testing.T) {
t.Run("disabled auth allows all requests", func(t *testing.T) {
ta, err := NewTokenAuth(false)
if err != nil {
t.Fatalf("NewTokenAuth() error = %v", err)
}
handler := ta.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected OK with disabled auth, got %d", rr.Code)
}
})
t.Run("enabled auth requires token", func(t *testing.T) {
ta, err := NewTokenAuth(true)
if err != nil {
t.Fatalf("NewTokenAuth() error = %v", err)
}
handler := ta.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Without token
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusUnauthorized {
t.Errorf("Expected Unauthorized without token, got %d", rr.Code)
}
// With correct token in X-Auth-Token header
req = httptest.NewRequest("GET", "/test", nil)
req.Header.Set("X-Auth-Token", ta.Token())
rr = httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected OK with correct token, got %d", rr.Code)
}
// With correct token in Authorization header
req = httptest.NewRequest("GET", "/test", nil)
req.Header.Set("Authorization", "Bearer "+ta.Token())
rr = httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected OK with Bearer token, got %d", rr.Code)
}
})
t.Run("exempt paths skip auth", func(t *testing.T) {
ta, err := NewTokenAuth(true)
if err != nil {
t.Fatalf("NewTokenAuth() error = %v", err)
}
handler := ta.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
exemptPaths := []string{"/health", "/api/health", "/api/ready"}
for _, path := range exemptPaths {
req := httptest.NewRequest("GET", path, nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected OK for exempt path %s, got %d", path, rr.Code)
}
}
})
}
func TestExpensiveOperationLimiter(t *testing.T) {
limiter := NewExpensiveOperationLimiter()
// First rebuild should be allowed
if !limiter.CanRebuild() {
t.Error("First rebuild should be allowed")
}
// Immediate second rebuild should be blocked
if limiter.CanRebuild() {
t.Error("Immediate second rebuild should be blocked")
}
}
func TestRequestID(t *testing.T) {
handler := RequestID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify request ID is in context
id := GetRequestID(r.Context())
if id == "" {
t.Error("Request ID should be set in context")
}
w.WriteHeader(http.StatusOK)
}))
t.Run("generates new request ID", func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Header().Get("X-Request-ID") == "" {
t.Error("X-Request-ID header should be set")
}
})
t.Run("uses existing request ID", func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("X-Request-ID", "test-id-12345")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Header().Get("X-Request-ID") != "test-id-12345" {
t.Errorf("Expected X-Request-ID to be test-id-12345, got %s", rr.Header().Get("X-Request-ID"))
}
})
}
func TestRequireJSONContentType(t *testing.T) {
handler := RequireJSONContentType(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
tests := []struct {
name string
method string
contentType string
expectedStatus int
}{
{
name: "GET request without content-type",
method: "GET",
contentType: "",
expectedStatus: http.StatusOK,
},
{
name: "POST with application/json",
method: "POST",
contentType: "application/json",
expectedStatus: http.StatusOK,
},
{
name: "POST with application/json; charset=utf-8",
method: "POST",
contentType: "application/json; charset=utf-8",
expectedStatus: http.StatusOK,
},
{
name: "POST without content-type (empty body)",
method: "POST",
contentType: "",
expectedStatus: http.StatusOK,
},
{
name: "POST with text/plain rejected",
method: "POST",
contentType: "text/plain",
expectedStatus: http.StatusUnsupportedMediaType,
},
{
name: "PUT with application/xml rejected",
method: "PUT",
contentType: "application/xml",
expectedStatus: http.StatusUnsupportedMediaType,
},
{
name: "PATCH with form-urlencoded rejected",
method: "PATCH",
contentType: "application/x-www-form-urlencoded",
expectedStatus: http.StatusUnsupportedMediaType,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(tt.method, "/test", nil)
if tt.contentType != "" {
req.Header.Set("Content-Type", tt.contentType)
}
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != tt.expectedStatus {
t.Errorf("Expected status %d, got %d", tt.expectedStatus, rr.Code)
}
})
}
}
func TestValidateProjectName(t *testing.T) {
tests := []struct {
name string
project string
wantError bool
}{
{
name: "empty project allowed",
project: "",
wantError: false,
},
{
name: "simple project name",
project: "my-project",
wantError: false,
},
{
name: "project with path",
project: "org/my-project",
wantError: false,
},
{
name: "project with underscore",
project: "my_project_v2",
wantError: false,
},
{
name: "project with dot",
project: "my.project.name",
wantError: false,
},
{
name: "path traversal attack",
project: "../../../etc/passwd",
wantError: true,
},
{
name: "hidden path traversal",
project: "project/../../secret",
wantError: true,
},
{
name: "shell injection attempt",
project: "project; rm -rf /",
wantError: true,
},
{
name: "backtick injection",
project: "project`whoami`",
wantError: true,
},
{
name: "special characters",
project: "project$HOME",
wantError: true,
},
{
name: "too long project name",
project: string(make([]byte, 501)),
wantError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateProjectName(tt.project)
if tt.wantError && err == nil {
t.Errorf("Expected error for project %q, got nil", tt.project)
}
if !tt.wantError && err != nil {
t.Errorf("Unexpected error for project %q: %v", tt.project, err)
}
})
}
}
func TestBulkOperationLimiter(t *testing.T) {
limiter := NewBulkOperationLimiter(1) // 1 second cooldown for testing
// First operation should be allowed
if !limiter.CanExecute() {
t.Error("First bulk operation should be allowed")
}
// Immediate second operation should be blocked
if limiter.CanExecute() {
t.Error("Immediate second bulk operation should be blocked")
}
// Check cooldown remaining
remaining := limiter.CooldownRemaining()
if remaining <= 0 || remaining > 1 {
t.Errorf("Expected cooldown remaining between 0-1 seconds, got %d", remaining)
}
// Check time since last op
since := limiter.TimeSinceLastOp()
if since < 0 || since > 1 {
t.Errorf("Expected time since last op between 0-1 seconds, got %d", since)
}
}
func TestSecurityHeaders_CSP(t *testing.T) {
handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
// Check CSP header is set
csp := rr.Header().Get("Content-Security-Policy")
if csp == "" {
t.Error("Content-Security-Policy header should be set")
}
if csp != "default-src 'self'" {
t.Errorf("Expected CSP to be \"default-src 'self'\", got %q", csp)
}
// Check Permissions-Policy header
pp := rr.Header().Get("Permissions-Policy")
if pp == "" {
t.Error("Permissions-Policy header should be set")
}
}
func TestSecurityHeaders_Preflight(t *testing.T) {
handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("OPTIONS", "/test", nil)
req.Header.Set("Origin", "http://localhost:3000")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
// OPTIONS should return 204 No Content
if rr.Code != http.StatusNoContent {
t.Errorf("Expected status 204 for OPTIONS, got %d", rr.Code)
}
// CORS headers should be set for allowed origin
if rr.Header().Get("Access-Control-Allow-Origin") != "http://localhost:3000" {
t.Errorf("CORS origin should be set for allowed origin")
}
if rr.Header().Get("Access-Control-Allow-Methods") == "" {
t.Error("Access-Control-Allow-Methods should be set")
}
}
+226
View File
@@ -0,0 +1,226 @@
// Package worker provides the main worker service for claude-mnemonic.
package worker
import (
"net/http"
"sync"
"time"
)
// RateLimiter implements a token bucket rate limiter.
type RateLimiter struct {
lastUpdate time.Time
rate float64
burst int
tokens float64
requests int64
rejected int64
mu sync.Mutex
}
// LastUpdateTime returns the last update time.
// Thread-safe - acquires the limiter's lock.
func (rl *RateLimiter) LastUpdateTime() time.Time {
rl.mu.Lock()
defer rl.mu.Unlock()
return rl.lastUpdate
}
// lastUpdateTimeUnlocked returns the last update time without locking.
// Caller must hold rl.mu.
func (rl *RateLimiter) lastUpdateTimeUnlocked() time.Time {
return rl.lastUpdate
}
// NewRateLimiter creates a new rate limiter.
// rate is the number of requests per second to allow.
// burst is the maximum burst of requests to allow.
func NewRateLimiter(rate float64, burst int) *RateLimiter {
return &RateLimiter{
rate: rate,
burst: burst,
tokens: float64(burst),
lastUpdate: time.Now(),
}
}
// Allow checks if a request should be allowed.
// Returns true if the request is allowed, false if rate limited.
func (rl *RateLimiter) Allow() bool {
rl.mu.Lock()
defer rl.mu.Unlock()
rl.requests++
// Calculate tokens added since last update
now := time.Now()
elapsed := now.Sub(rl.lastUpdate).Seconds()
rl.tokens += elapsed * rl.rate
if rl.tokens > float64(rl.burst) {
rl.tokens = float64(rl.burst)
}
rl.lastUpdate = now
// Check if we have a token available
if rl.tokens >= 1 {
rl.tokens--
return true
}
rl.rejected++
return false
}
// Stats returns rate limiter statistics.
func (rl *RateLimiter) Stats() map[string]any {
rl.mu.Lock()
defer rl.mu.Unlock()
return map[string]any{
"rate": rl.rate,
"burst": rl.burst,
"current_tokens": rl.tokens,
"total_requests": rl.requests,
"rejected": rl.rejected,
"rejection_rate": float64(rl.rejected) / max(float64(rl.requests), 1),
}
}
// RateLimitMiddleware creates middleware that applies rate limiting.
// Uses a shared rate limiter for all requests.
func RateLimitMiddleware(limiter *RateLimiter) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !limiter.Allow() {
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
})
}
}
// PerClientRateLimiter implements per-client rate limiting.
type PerClientRateLimiter struct {
lastCleanup time.Time
clients map[string]*RateLimiter
rate float64
burst int
cleanupInterval time.Duration
maxIdleTime time.Duration
mu sync.Mutex
}
// NewPerClientRateLimiter creates a new per-client rate limiter.
func NewPerClientRateLimiter(rate float64, burst int) *PerClientRateLimiter {
return &PerClientRateLimiter{
rate: rate,
burst: burst,
clients: make(map[string]*RateLimiter),
cleanupInterval: 5 * time.Minute,
maxIdleTime: 10 * time.Minute,
lastCleanup: time.Now(),
}
}
// getLimiter returns a rate limiter for the given client key.
func (pcrl *PerClientRateLimiter) getLimiter(key string) *RateLimiter {
pcrl.mu.Lock()
defer pcrl.mu.Unlock()
// Periodic cleanup of idle clients
if time.Since(pcrl.lastCleanup) > pcrl.cleanupInterval {
pcrl.cleanupLocked()
}
limiter, exists := pcrl.clients[key]
if !exists {
limiter = NewRateLimiter(pcrl.rate, pcrl.burst)
pcrl.clients[key] = limiter
}
return limiter
}
// cleanupLocked removes idle limiters. Must be called with lock held.
// Uses consistent lock ordering: always acquire limiter.mu while holding pcrl.mu.
// This is safe because the limiter.mu critical section is brief (just reading lastUpdate).
func (pcrl *PerClientRateLimiter) cleanupLocked() {
now := time.Now()
keysToDelete := make([]string, 0)
// Check each limiter while holding pcrl.mu
// We briefly acquire limiter.mu but the critical section is minimal
for key, limiter := range pcrl.clients {
limiter.mu.Lock()
lastUpdate := limiter.lastUpdateTimeUnlocked()
limiter.mu.Unlock()
if now.Sub(lastUpdate) > pcrl.maxIdleTime {
keysToDelete = append(keysToDelete, key)
}
}
// Delete collected keys
for _, key := range keysToDelete {
delete(pcrl.clients, key)
}
pcrl.lastCleanup = now
}
// Allow checks if a request from the given client should be allowed.
func (pcrl *PerClientRateLimiter) Allow(clientKey string) bool {
return pcrl.getLimiter(clientKey).Allow()
}
// Stats returns aggregate statistics.
// Uses two-phase approach to avoid nested lock acquisition.
func (pcrl *PerClientRateLimiter) Stats() map[string]any {
// Phase 1: Collect limiters under pcrl.mu
pcrl.mu.Lock()
rate := pcrl.rate
burst := pcrl.burst
activeClients := len(pcrl.clients)
limiters := make([]*RateLimiter, 0, activeClients)
for _, limiter := range pcrl.clients {
limiters = append(limiters, limiter)
}
pcrl.mu.Unlock()
// Phase 2: Collect stats from each limiter (only acquiring limiter.mu, not pcrl.mu)
var totalRequests, totalRejected int64
for _, limiter := range limiters {
limiter.mu.Lock()
totalRequests += limiter.requests
totalRejected += limiter.rejected
limiter.mu.Unlock()
}
return map[string]any{
"rate": rate,
"burst": burst,
"active_clients": activeClients,
"total_requests": totalRequests,
"total_rejected": totalRejected,
}
}
// PerClientRateLimitMiddleware creates middleware that applies per-client rate limiting.
// Uses X-Forwarded-For or RemoteAddr to identify clients.
func PerClientRateLimitMiddleware(limiter *PerClientRateLimiter) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get client identifier (prefer X-Real-IP from RealIP middleware)
clientKey := r.RemoteAddr
if xff := r.Header.Get("X-Real-IP"); xff != "" {
clientKey = xff
}
if !limiter.Allow(clientKey) {
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
})
}
}
+428 -43
View File
@@ -4,11 +4,15 @@ package sdk
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
json "github.com/goccy/go-json"
@@ -20,8 +24,178 @@ import (
"github.com/rs/zerolog/log"
)
// CircuitBreaker implements a simple circuit breaker pattern for CLI calls.
type CircuitBreaker struct {
failures int64 // Current failure count
lastFailure int64 // Unix timestamp of last failure
threshold int64 // Number of failures before opening
resetTimeout int64 // Seconds to wait before trying again
state int32 // 0=closed, 1=open, 2=half-open
}
const (
circuitClosed int32 = 0
circuitOpen int32 = 1
circuitHalfOpen int32 = 2
)
// NewCircuitBreaker creates a new circuit breaker.
func NewCircuitBreaker(threshold int64, resetTimeout int64) *CircuitBreaker {
return &CircuitBreaker{
threshold: threshold,
resetTimeout: resetTimeout,
}
}
// Allow checks if a request should be allowed through.
func (cb *CircuitBreaker) Allow() bool {
state := atomic.LoadInt32(&cb.state)
if state == circuitClosed {
return true
}
if state == circuitOpen {
// Check if reset timeout has passed
lastFail := atomic.LoadInt64(&cb.lastFailure)
if time.Now().Unix()-lastFail > cb.resetTimeout {
// Transition to half-open
atomic.CompareAndSwapInt32(&cb.state, circuitOpen, circuitHalfOpen)
return true
}
return false
}
// Half-open: allow one request through
return true
}
// RecordSuccess records a successful call.
func (cb *CircuitBreaker) RecordSuccess() {
atomic.StoreInt64(&cb.failures, 0)
atomic.StoreInt32(&cb.state, circuitClosed)
}
// RecordFailure records a failed call.
func (cb *CircuitBreaker) RecordFailure() {
failures := atomic.AddInt64(&cb.failures, 1)
atomic.StoreInt64(&cb.lastFailure, time.Now().Unix())
if failures >= cb.threshold {
atomic.StoreInt32(&cb.state, circuitOpen)
log.Warn().Int64("failures", failures).Msg("Circuit breaker opened - Claude CLI calls temporarily disabled")
}
}
// State returns the current state as a string.
func (cb *CircuitBreaker) State() string {
switch atomic.LoadInt32(&cb.state) {
case circuitOpen:
return "open"
case circuitHalfOpen:
return "half-open"
default:
return "closed"
}
}
// CircuitBreakerMetrics contains metrics about the circuit breaker state.
type CircuitBreakerMetrics struct {
State string `json:"state"`
Failures int64 `json:"failures"`
Threshold int64 `json:"threshold"`
ResetTimeoutSecs int64 `json:"reset_timeout_secs"`
LastFailureUnix int64 `json:"last_failure_unix,omitempty"`
SecondsUntilReset int64 `json:"seconds_until_reset,omitempty"`
}
// Metrics returns the current metrics of the circuit breaker.
func (cb *CircuitBreaker) Metrics() CircuitBreakerMetrics {
failures := atomic.LoadInt64(&cb.failures)
lastFail := atomic.LoadInt64(&cb.lastFailure)
state := cb.State()
metrics := CircuitBreakerMetrics{
State: state,
Failures: failures,
Threshold: cb.threshold,
ResetTimeoutSecs: cb.resetTimeout,
}
if lastFail > 0 {
metrics.LastFailureUnix = lastFail
if state == "open" {
remaining := cb.resetTimeout - (time.Now().Unix() - lastFail)
if remaining > 0 {
metrics.SecondsUntilReset = remaining
}
}
}
return metrics
}
// RequestDeduplicator tracks recent requests to prevent duplicates.
type RequestDeduplicator struct {
seen map[string]int64 // hash -> timestamp
mu sync.RWMutex
ttlSecs int64
maxSize int
}
// NewRequestDeduplicator creates a new deduplicator.
func NewRequestDeduplicator(ttlSecs int64, maxSize int) *RequestDeduplicator {
return &RequestDeduplicator{
seen: make(map[string]int64),
ttlSecs: ttlSecs,
maxSize: maxSize,
}
}
// IsDuplicate checks if a request hash was seen recently.
func (d *RequestDeduplicator) IsDuplicate(hash string) bool {
now := time.Now().Unix()
d.mu.RLock()
ts, exists := d.seen[hash]
d.mu.RUnlock()
if exists && now-ts < d.ttlSecs {
return true
}
return false
}
// Record marks a request hash as seen.
func (d *RequestDeduplicator) Record(hash string) {
now := time.Now().Unix()
d.mu.Lock()
defer d.mu.Unlock()
// Evict old entries if at capacity
if len(d.seen) >= d.maxSize {
threshold := now - d.ttlSecs
for k, ts := range d.seen {
if ts < threshold {
delete(d.seen, k)
}
}
}
d.seen[hash] = now
}
// hashRequest creates a hash of a request for deduplication.
func hashRequest(toolName, input, output string) string {
h := sha256.New()
h.Write([]byte(toolName))
h.Write([]byte(input))
h.Write([]byte(output[:min(len(output), 1000)])) // Only hash first 1000 chars of output
return hex.EncodeToString(h.Sum(nil))[:16] // Short hash is sufficient
}
// BroadcastFunc is a callback for broadcasting events to SSE clients.
type BroadcastFunc func(event map[string]interface{})
type BroadcastFunc func(event map[string]any)
// SyncObservationFunc is a callback for syncing observations to vector DB.
type SyncObservationFunc func(obs *models.Observation)
@@ -29,16 +203,26 @@ type SyncObservationFunc func(obs *models.Observation)
// SyncSummaryFunc is a callback for syncing summaries to vector DB.
type SyncSummaryFunc func(summary *models.SessionSummary)
// MaxVectorSyncWorkers is the maximum number of concurrent vector sync operations.
// This prevents unbounded goroutine spawning during high-volume observation ingestion.
const MaxVectorSyncWorkers = 8
// Processor handles SDK agent processing of observations and summaries using Claude Code CLI.
// Field order optimized for memory alignment (fieldalignment).
type Processor struct {
observationStore *gorm.ObservationStore
summaryStore *gorm.SummaryStore
broadcastFunc BroadcastFunc
syncObservationFunc SyncObservationFunc
syncSummaryFunc SyncSummaryFunc
circuitBreaker *CircuitBreaker
deduplicator *RequestDeduplicator
vectorSyncChan chan *models.Observation
vectorSyncDone chan struct{}
sem chan struct{}
claudePath string
model string
vectorSyncWg sync.WaitGroup
}
// SetBroadcastFunc sets the broadcast callback for SSE events.
@@ -57,7 +241,7 @@ func (p *Processor) SetSyncSummaryFunc(fn SyncSummaryFunc) {
}
// broadcast sends an event via the broadcast callback if set.
func (p *Processor) broadcast(event map[string]interface{}) {
func (p *Processor) broadcast(event map[string]any) {
if p.broadcastFunc != nil {
p.broadcastFunc(event)
}
@@ -93,9 +277,65 @@ func NewProcessor(observationStore *gorm.ObservationStore, summaryStore *gorm.Su
observationStore: observationStore,
summaryStore: summaryStore,
sem: make(chan struct{}, MaxConcurrentCLICalls),
circuitBreaker: NewCircuitBreaker(5, 60), // Open after 5 failures, reset after 60s
deduplicator: NewRequestDeduplicator(300, 1000), // 5-minute TTL, 1000 max entries
vectorSyncChan: make(chan *models.Observation, MaxVectorSyncWorkers*2), // Buffered channel
vectorSyncDone: make(chan struct{}),
}, nil
}
// StartVectorSyncWorkers starts the bounded worker pool for vector sync operations.
// Call this after setting the sync function via SetSyncObservationFunc.
func (p *Processor) StartVectorSyncWorkers() {
for i := 0; i < MaxVectorSyncWorkers; i++ {
p.vectorSyncWg.Add(1)
go p.vectorSyncWorker()
}
log.Info().Int("workers", MaxVectorSyncWorkers).Msg("Vector sync worker pool started")
}
// StopVectorSyncWorkers gracefully stops the worker pool.
func (p *Processor) StopVectorSyncWorkers() {
close(p.vectorSyncDone)
p.vectorSyncWg.Wait()
log.Info().Msg("Vector sync worker pool stopped")
}
// vectorSyncWorker is a worker goroutine that processes vector sync requests.
func (p *Processor) vectorSyncWorker() {
defer p.vectorSyncWg.Done()
for {
select {
case <-p.vectorSyncDone:
// Drain remaining items before exiting
for {
select {
case obs := <-p.vectorSyncChan:
if p.syncObservationFunc != nil {
p.syncObservationFunc(obs)
}
default:
return
}
}
case obs := <-p.vectorSyncChan:
if p.syncObservationFunc != nil {
p.syncObservationFunc(obs)
}
}
}
}
// CircuitBreakerState returns the current state of the circuit breaker.
func (p *Processor) CircuitBreakerState() string {
return p.circuitBreaker.State()
}
// CircuitBreakerMetrics returns detailed metrics about the circuit breaker.
func (p *Processor) CircuitBreakerMetrics() CircuitBreakerMetrics {
return p.circuitBreaker.Metrics()
}
// IsAvailable checks if the Claude CLI is available for processing.
func (p *Processor) IsAvailable() bool {
_, err := os.Stat(p.claudePath)
@@ -103,7 +343,7 @@ func (p *Processor) IsAvailable() bool {
}
// ProcessObservation processes a single tool observation and extracts insights.
func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, project string, toolName string, toolInput, toolResponse interface{}, promptNumber int, cwd string) error {
func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, project string, toolName string, toolInput, toolResponse any, promptNumber int, cwd string) error {
// Skip certain tools that aren't worth processing
if shouldSkipTool(toolName) {
log.Info().Str("tool", toolName).Msg("Skipping tool (not interesting for memory)")
@@ -120,11 +360,23 @@ func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, projec
return nil
}
// Check for duplicate request within TTL window
reqHash := hashRequest(toolName, inputStr, outputStr)
if p.deduplicator.IsDuplicate(reqHash) {
log.Debug().Str("tool", toolName).Msg("Skipping duplicate request (dedup)")
return nil
}
// Check circuit breaker before making CLI call
if !p.circuitBreaker.Allow() {
log.Warn().Str("tool", toolName).Msg("Circuit breaker open - skipping CLI call")
return fmt.Errorf("circuit breaker open")
}
log.Info().Str("tool", toolName).Msg("Processing tool execution with Claude CLI")
// Note: Removed the "file already has observations" check
// Each tool execution can produce unique insights even for the same file
// Similarity-based deduplication will handle true duplicates
// Record this request to prevent duplicates
p.deduplicator.Record(reqHash)
// Build the prompt
exec := ToolExecution{
@@ -146,9 +398,11 @@ func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, projec
// Call Claude Code CLI
response, err := p.callClaudeCLI(ctx, prompt)
if err != nil {
p.circuitBreaker.RecordFailure()
log.Error().Err(err).Str("tool", toolName).Msg("Failed to call Claude CLI for observation")
return err
}
p.circuitBreaker.RecordSuccess()
// Parse observations from response
observations := ParseObservations(response, sdkSessionID)
@@ -199,16 +453,26 @@ func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, projec
Int("trackedFiles", len(obs.FileMtimes)).
Msg("Observation stored")
// Sync to vector DB if callback is set
if p.syncObservationFunc != nil {
// Sync to vector DB via bounded worker pool (non-blocking to reduce latency)
if p.syncObservationFunc != nil && p.vectorSyncChan != nil {
fullObs := models.NewObservation(sdkSessionID, project, obs, promptNumber, 0)
fullObs.ID = id
fullObs.CreatedAtEpoch = createdAtEpoch
p.syncObservationFunc(fullObs)
// Non-blocking send to worker pool - drops if channel is full
select {
case p.vectorSyncChan <- fullObs:
// Sent to worker pool
default:
// Channel full, fall back to direct sync in goroutine (bounded by channel buffer)
log.Debug().Int64("obs_id", id).Msg("Vector sync channel full, using fallback goroutine")
go func(obsToSync *models.Observation) {
p.syncObservationFunc(obsToSync)
}(fullObs)
}
}
// Broadcast new observation event for dashboard refresh
p.broadcast(map[string]interface{}{
p.broadcast(map[string]any{
"type": "observation",
"action": "created",
"id": id,
@@ -310,7 +574,7 @@ func (p *Processor) ProcessSummary(ctx context.Context, sessionDBID int64, sdkSe
}
// Broadcast new summary event for dashboard refresh
p.broadcast(map[string]interface{}{
p.broadcast(map[string]any{
"type": "summary",
"action": "created",
"id": id,
@@ -320,8 +584,31 @@ func (p *Processor) ProcessSummary(ctx context.Context, sessionDBID int64, sdkSe
return nil
}
// MaxPromptSize is the maximum size of a prompt that can be passed to the Claude CLI.
// This prevents resource exhaustion from extremely large prompts.
const MaxPromptSize = 100 * 1024 // 100KB
// sanitizePrompt removes null bytes and control characters from a prompt.
// Keeps newlines, tabs, and carriage returns as they're valid in prompts.
func sanitizePrompt(s string) string {
return strings.Map(func(r rune) rune {
// Keep printable ASCII, extended Unicode, and common whitespace
if r >= 32 || r == '\n' || r == '\t' || r == '\r' {
return r
}
// Remove null bytes and other control characters
return -1
}, s)
}
// callClaudeCLI calls the Claude Code CLI with the given prompt.
func (p *Processor) callClaudeCLI(ctx context.Context, prompt string) (string, error) {
// Validate and sanitize prompt
if len(prompt) > MaxPromptSize {
return "", fmt.Errorf("prompt exceeds maximum size of %d bytes", MaxPromptSize)
}
prompt = sanitizePrompt(prompt)
// Build the full prompt with system instructions
fullPrompt := systemPrompt + "\n\n" + prompt
@@ -418,8 +705,11 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
return true
}
// Skip if output indicates an error or empty result
// Pre-compute lowercase strings once to avoid repeated allocations
lowerOutput := strings.ToLower(outputStr)
lowerInput := strings.ToLower(inputStr)
// Skip if output indicates an error or empty result
trivialOutputs := []string{
"no matches found",
"file not found",
@@ -443,13 +733,13 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
// Skip reading config files that rarely contain project-specific insights
boringFiles := []string{
"package-lock.json", "yarn.lock", "pnpm-lock.yaml",
"go.sum", "Cargo.lock", "Gemfile.lock", "poetry.lock",
"go.sum", "cargo.lock", "gemfile.lock", "poetry.lock",
".gitignore", ".dockerignore", ".eslintignore",
"tsconfig.json", "jsconfig.json", "vite.config",
"tailwind.config", "postcss.config",
}
for _, boring := range boringFiles {
if strings.Contains(inputStr, boring) {
if strings.Contains(lowerInput, boring) {
return true
}
}
@@ -461,14 +751,14 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
}
case "Bash":
// Skip simple status commands
// Skip simple status commands (use pre-computed lowerInput)
boringCommands := []string{
"git status", "git diff", "git log", "git branch",
"ls ", "pwd", "echo ", "cat ", "which ", "type ",
"npm list", "npm outdated", "npm audit",
}
for _, boring := range boringCommands {
if strings.Contains(strings.ToLower(inputStr), boring) {
if strings.Contains(lowerInput, boring) {
return true
}
}
@@ -478,7 +768,7 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
}
// toJSONString converts an interface to a JSON string.
func toJSONString(v interface{}) string {
func toJSONString(v any) string {
if v == nil {
return ""
}
@@ -492,38 +782,132 @@ func toJSONString(v interface{}) string {
return string(b)
}
// safeResolvePath resolves a path relative to cwd and validates it doesn't escape the cwd directory.
// Returns the resolved absolute path and true if valid, or empty string and false if path traversal detected.
// This function is a security sanitizer for path traversal attacks.
func safeResolvePath(path, cwd string) (string, bool) {
// Clean the input path to normalize any .. or . components
cleanPath := filepath.Clean(path)
// Reject paths that explicitly contain parent directory traversal after cleaning
if strings.Contains(cleanPath, "..") {
return "", false
}
if filepath.IsAbs(cleanPath) {
// For absolute paths, verify they're within cwd if cwd is specified
if cwd != "" {
cleanCwd := filepath.Clean(cwd)
if !strings.HasPrefix(cleanPath, cleanCwd+string(filepath.Separator)) && cleanPath != cleanCwd {
return "", false
}
}
return cleanPath, true
}
if cwd == "" {
return cleanPath, true
}
// Clean the cwd first
cleanCwd := filepath.Clean(cwd)
// Join and clean the path
absPath := filepath.Join(cleanCwd, cleanPath)
// Use filepath.Rel to verify the path is actually within cwd
// If Rel returns a path starting with "..", it escapes the base
rel, err := filepath.Rel(cleanCwd, absPath)
if err != nil || strings.HasPrefix(rel, "..") {
return "", false
}
return absPath, true
}
// captureFileMtimes captures current modification times for tracked files.
// Returns a map of absolute file paths to their mtime in epoch milliseconds.
// For large file lists (>10 files), uses parallel stat calls for better performance.
func captureFileMtimes(filesRead, filesModified []string, cwd string) map[string]int64 {
mtimes := make(map[string]int64)
// Combine all unique file paths
allPaths := make(map[string]struct{}, len(filesRead)+len(filesModified))
for _, path := range filesRead {
allPaths[path] = struct{}{}
}
for _, path := range filesModified {
allPaths[path] = struct{}{}
}
// Helper to get mtime for a file path
getMtime := func(path string) (int64, bool) {
// Resolve relative paths against cwd
absPath := path
if !filepath.IsAbs(path) && cwd != "" {
absPath = filepath.Join(cwd, path)
// For small lists, use sequential processing (goroutine overhead not worth it)
if len(allPaths) <= 10 {
return captureFileMtimesSequential(allPaths, cwd)
}
// For larger lists, parallelize with bounded concurrency
return captureFileMtimesParallel(allPaths, cwd)
}
// captureFileMtimesSequential captures mtimes sequentially (efficient for small lists).
func captureFileMtimesSequential(paths map[string]struct{}, cwd string) map[string]int64 {
mtimes := make(map[string]int64, len(paths))
for path := range paths {
absPath, ok := safeResolvePath(path, cwd)
if !ok {
// Skip paths that attempt directory traversal
continue
}
info, err := os.Stat(absPath)
if err != nil {
return 0, false
}
return info.ModTime().UnixMilli(), true
}
// Capture mtimes for all read files
for _, path := range filesRead {
if mtime, ok := getMtime(path); ok {
mtimes[path] = mtime
if err == nil {
mtimes[path] = info.ModTime().UnixMilli()
}
}
// Capture mtimes for all modified files
for _, path := range filesModified {
if mtime, ok := getMtime(path); ok {
mtimes[path] = mtime
}
return mtimes
}
// captureFileMtimesParallel captures mtimes in parallel with bounded concurrency.
func captureFileMtimesParallel(paths map[string]struct{}, cwd string) map[string]int64 {
type mtimeResult struct {
path string
mtime int64
}
results := make(chan mtimeResult, len(paths))
sem := make(chan struct{}, 8) // Limit to 8 concurrent stat calls
var wg sync.WaitGroup
for path := range paths {
wg.Add(1)
go func(p string) {
defer wg.Done()
sem <- struct{}{} // Acquire
defer func() { <-sem }() // Release
absPath, ok := safeResolvePath(p, cwd)
if !ok {
// Skip paths that attempt directory traversal
return
}
info, err := os.Stat(absPath)
if err == nil {
results <- mtimeResult{path: p, mtime: info.ModTime().UnixMilli()}
}
}(path)
}
// Close results channel when all goroutines complete
go func() {
wg.Wait()
close(results)
}()
// Collect results
mtimes := make(map[string]int64, len(paths))
for res := range results {
mtimes[res.path] = res.mtime
}
return mtimes
@@ -538,12 +922,13 @@ func GetFileMtimes(paths []string, cwd string) map[string]int64 {
// GetFileContent reads file content for verification purposes.
// Returns content and ok status.
func GetFileContent(path, cwd string) (string, bool) {
absPath := path
if !filepath.IsAbs(path) && cwd != "" {
absPath = filepath.Join(cwd, path)
absPath, ok := safeResolvePath(path, cwd)
if !ok {
// Reject paths that attempt directory traversal
return "", false
}
content, err := os.ReadFile(absPath) // #nosec G304 -- intentional file read for verification
content, err := os.ReadFile(absPath) // #nosec G304 -- path validated by safeResolvePath
if err != nil {
return "", false
}
+204
View File
@@ -974,3 +974,207 @@ func TestSyncSummaryFuncType(t *testing.T) {
}
assert.NotNil(t, fn)
}
// TestSanitizePrompt tests prompt sanitization for CLI safety.
func TestSanitizePrompt(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "normal text",
input: "Hello, world!",
expected: "Hello, world!",
},
{
name: "text with newlines",
input: "Line 1\nLine 2\nLine 3",
expected: "Line 1\nLine 2\nLine 3",
},
{
name: "text with tabs",
input: "Key:\tValue",
expected: "Key:\tValue",
},
{
name: "text with carriage return",
input: "Line 1\r\nLine 2",
expected: "Line 1\r\nLine 2",
},
{
name: "text with null bytes",
input: "Hello\x00World",
expected: "HelloWorld",
},
{
name: "text with control characters",
input: "Hello\x01\x02\x03World",
expected: "HelloWorld",
},
{
name: "text with bell character",
input: "Hello\x07World",
expected: "HelloWorld",
},
{
name: "text with backspace",
input: "Hello\x08World",
expected: "HelloWorld",
},
{
name: "text with form feed",
input: "Hello\x0cWorld",
expected: "HelloWorld",
},
{
name: "text with escape",
input: "Hello\x1bWorld",
expected: "HelloWorld",
},
{
name: "unicode text",
input: "Hello 世界 🌍",
expected: "Hello 世界 🌍",
},
{
name: "empty string",
input: "",
expected: "",
},
{
name: "only control characters",
input: "\x00\x01\x02\x03",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := sanitizePrompt(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
// TestMaxPromptSize tests that MaxPromptSize is reasonable.
func TestMaxPromptSize(t *testing.T) {
assert.Equal(t, 100*1024, MaxPromptSize)
}
// BenchmarkSanitizePrompt benchmarks the sanitize function.
func BenchmarkSanitizePrompt(b *testing.B) {
prompt := "Analyze the following code:\n```go\nfunc main() {\n\tfmt.Println(\"Hello, World!\")\n}\n```\n\nPlease identify any issues."
b.ResetTimer()
for i := 0; i < b.N; i++ {
sanitizePrompt(prompt)
}
}
// BenchmarkSanitizePromptWithControlChars benchmarks sanitization with control characters.
func BenchmarkSanitizePromptWithControlChars(b *testing.B) {
prompt := "Hello\x00World\x01Test\x02Data\x03End"
b.ResetTimer()
for i := 0; i < b.N; i++ {
sanitizePrompt(prompt)
}
}
// TestSafeResolvePath tests the path traversal protection.
func TestSafeResolvePath(t *testing.T) {
// Create a temporary directory for testing
tmpDir := t.TempDir()
tests := []struct {
name string
path string
cwd string
wantPath string
wantOk bool
}{
{
name: "simple relative path",
path: "file.txt",
cwd: tmpDir,
wantOk: true,
wantPath: filepath.Join(tmpDir, "file.txt"),
},
{
name: "nested relative path",
path: "subdir/file.txt",
cwd: tmpDir,
wantOk: true,
wantPath: filepath.Join(tmpDir, "subdir", "file.txt"),
},
{
name: "path traversal with ..",
path: "../etc/passwd",
cwd: tmpDir,
wantOk: false,
},
{
name: "path traversal with multiple ..",
path: "../../etc/passwd",
cwd: tmpDir,
wantOk: false,
},
{
name: "path traversal hidden in middle",
path: "subdir/../../../etc/passwd",
cwd: tmpDir,
wantOk: false,
},
{
name: "just parent directory",
path: "..",
cwd: tmpDir,
wantOk: false,
},
{
name: "absolute path without cwd",
path: "/some/absolute/path",
cwd: "",
wantOk: true,
wantPath: "/some/absolute/path",
},
{
name: "relative path without cwd",
path: "relative/path",
cwd: "",
wantOk: true,
wantPath: "relative/path",
},
{
name: "current directory reference",
path: "./file.txt",
cwd: tmpDir,
wantOk: true,
wantPath: filepath.Join(tmpDir, "file.txt"),
},
{
name: "absolute path outside cwd",
path: "/etc/passwd",
cwd: tmpDir,
wantOk: false,
},
{
name: "absolute path inside cwd",
path: filepath.Join(tmpDir, "inside.txt"),
cwd: tmpDir,
wantOk: true,
wantPath: filepath.Join(tmpDir, "inside.txt"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotPath, gotOk := safeResolvePath(tt.path, tt.cwd)
assert.Equal(t, tt.wantOk, gotOk, "ok status mismatch")
if tt.wantPath != "" && gotOk {
assert.Equal(t, tt.wantPath, gotPath, "path mismatch")
}
})
}
}
+678 -443
View File
File diff suppressed because it is too large Load Diff
+16 -4
View File
@@ -5,6 +5,8 @@ import (
"io/fs"
"net/http"
"strings"
"github.com/rs/zerolog/log"
)
//go:embed static/*
@@ -13,16 +15,22 @@ var staticFS embed.FS
// staticSubFS is the static subdirectory filesystem
var staticSubFS fs.FS
// staticInitErr stores any error from static filesystem initialization
var staticInitErr error
func init() {
var err error
staticSubFS, err = fs.Sub(staticFS, "static")
if err != nil {
panic("failed to create sub filesystem: " + err.Error())
staticSubFS, staticInitErr = fs.Sub(staticFS, "static")
if staticInitErr != nil {
log.Warn().Err(staticInitErr).Msg("Static filesystem initialization failed - dashboard will be unavailable")
}
}
// serveIndex serves the index.html file for the root path
func serveIndex(w http.ResponseWriter, r *http.Request) {
if staticInitErr != nil {
http.Error(w, "Dashboard unavailable: static files not initialized", http.StatusServiceUnavailable)
return
}
content, err := fs.ReadFile(staticSubFS, "index.html")
if err != nil {
http.Error(w, "Dashboard not found", http.StatusNotFound)
@@ -38,6 +46,10 @@ func serveIndex(w http.ResponseWriter, r *http.Request) {
// serveAssets serves static assets from the embedded filesystem
func serveAssets(w http.ResponseWriter, r *http.Request) {
if staticInitErr != nil {
http.Error(w, "Assets unavailable: static files not initialized", http.StatusServiceUnavailable)
return
}
// Strip the /assets/ prefix and serve the file
path := strings.TrimPrefix(r.URL.Path, "/")