mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-13 02:06:24 +00:00
Make things 'betterer' across the board (#23)
* Make things 'betterer' across the board * fix: reorganize struct fields and config parameters for consistency - [x] Reorder Config struct fields alphabetically and by related functionality - [x] Reorganize Observation model fields with archival fields grouped together - [x] Reorder ObservationStore fields to group related members - [x] Reorder Store struct fields with health check caching grouped - [x] Reorganize HealthInfo and PoolMetrics struct field order - [x] Reorder maintenance Service struct fields logically - [x] Reorganize MCP server handler parameter structs alphabetically - [x] Reorder pattern detector candidate tracking fields - [x] Reorganize search Manager struct fields by functionality - [x] Reorder vector Client struct fields with mutex protections grouped - [x] Reorganize handler request/response struct fields - [x] Update handlers_test.go to expect wrapped response format - [x] Reorder middleware TokenAuth and rate limiter fields - [x] Reorganize Service struct fields with grouped functionality - [x] Fix RateLimiter field ordering for clarity - [x] Reorder CircuitBreaker metrics fields * fix(security): improve JSON output safety and path traversal protection - [x] Replace unsafe JSON string formatting with proper json.Marshal in export handler - [x] Remove escapeJSONString helper function in favor of standard JSON marshaling - [x] Add safeResolvePath function to validate paths and prevent directory traversal - [x] Apply path traversal validation in captureFileMtimes operations - [x] Cap result slice capacity in getRecentSearchQueries to prevent DoS via excessive allocation * fix(sdk): improve path traversal protection and allocation safety - [x] Enhance safeResolvePath with stricter validation using filepath.Rel - [x] Reject paths containing ".." after cleaning to prevent traversal - [x] Validate absolute paths are within cwd when cwd is specified - [x] Apply safeResolvePath validation to GetFileContent for consistency - [x] Add comprehensive test coverage for path traversal protection - [x] Fix allocation safety in getRecentSearchQueries by using constant capacity
This commit is contained in:
+125
-1273
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,677 @@
|
||||
// Package worker provides context and search-related HTTP handlers.
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/reranking"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/search/expansion"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/sdk"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// handleSearchByPrompt searches observations relevant to a user prompt.
|
||||
// IMPORTANT: This is on the critical startup path - must be fast!
|
||||
// No synchronous verification - just filter by staleness and return.
|
||||
func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
query := r.URL.Query().Get("query")
|
||||
cwd := r.URL.Query().Get("cwd")
|
||||
|
||||
if project == "" || query == "" {
|
||||
http.Error(w, "project and query required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate project name to prevent path traversal
|
||||
if err := ValidateProjectName(project); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
limit := gorm.ParseLimitParamWithMax(r, DefaultSearchLimit, 200)
|
||||
|
||||
var observations []*models.Observation
|
||||
var err error
|
||||
var usedVector bool
|
||||
similarityScores := make(map[int64]float64) // Track similarity per observation
|
||||
|
||||
// Get threshold settings from config
|
||||
threshold := s.config.ContextRelevanceThreshold
|
||||
maxResults := s.config.ContextMaxPromptResults
|
||||
|
||||
// Generate expanded queries if query expander is available
|
||||
// Use timeout context to prevent query expansion from blocking
|
||||
var expandedQueries []expansion.ExpandedQuery
|
||||
var detectedIntent string
|
||||
if s.queryExpander != nil {
|
||||
expandCtx, expandCancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
cfg := expansion.DefaultConfig()
|
||||
cfg.EnableVocabularyExpansion = false // Vocabulary expansion is optional
|
||||
expandedQueries = s.queryExpander.Expand(expandCtx, query, cfg)
|
||||
expandCancel() // Cancel immediately after use (defer not needed - no panic possible between creation and here)
|
||||
if len(expandedQueries) > 0 {
|
||||
detectedIntent = string(expandedQueries[0].Intent)
|
||||
}
|
||||
}
|
||||
if len(expandedQueries) == 0 {
|
||||
// Fallback to just the original query
|
||||
expandedQueries = []expansion.ExpandedQuery{
|
||||
{Query: query, Weight: 1.0, Source: "original"},
|
||||
}
|
||||
}
|
||||
|
||||
// Try vector search first if available
|
||||
var vectorSearchFailed bool
|
||||
if s.vectorClient != nil && s.vectorClient.IsConnected() {
|
||||
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "")
|
||||
|
||||
// Search with each expanded query and merge results
|
||||
// Pre-allocate with estimated capacity to avoid repeated reallocation
|
||||
estimatedCapacity := len(expandedQueries) * limit * 2
|
||||
allVectorResults := make([]sqlitevec.QueryResult, 0, estimatedCapacity)
|
||||
queryWeights := make(map[string]float64, len(expandedQueries))
|
||||
var vectorErrors int
|
||||
|
||||
for _, eq := range expandedQueries {
|
||||
vectorResults, vecErr := s.vectorClient.Query(r.Context(), eq.Query, limit*2, where)
|
||||
if vecErr != nil {
|
||||
vectorErrors++
|
||||
log.Debug().Err(vecErr).Str("query", eq.Query).Msg("Vector query failed")
|
||||
} else if len(vectorResults) > 0 {
|
||||
// Apply weight to similarity scores before merging
|
||||
for i := range vectorResults {
|
||||
vectorResults[i].Similarity *= eq.Weight
|
||||
}
|
||||
allVectorResults = append(allVectorResults, vectorResults...)
|
||||
queryWeights[eq.Query] = eq.Weight
|
||||
}
|
||||
}
|
||||
|
||||
// Track if vector search had issues
|
||||
if vectorErrors > 0 && vectorErrors == len(expandedQueries) {
|
||||
vectorSearchFailed = true
|
||||
log.Warn().Int("errors", vectorErrors).Str("project", project).Msg("All vector queries failed, falling back to FTS")
|
||||
}
|
||||
|
||||
if len(allVectorResults) > 0 {
|
||||
// Filter by relevance threshold before extracting IDs
|
||||
// Use a slightly lower threshold for expanded queries
|
||||
effectiveThreshold := threshold * 0.9 // Allow slightly lower scores for expanded queries
|
||||
filteredResults := sqlitevec.FilterByThreshold(allVectorResults, effectiveThreshold, 0)
|
||||
|
||||
// Build similarity map for filtered results (keeping highest weighted score per observation)
|
||||
for _, vr := range filteredResults {
|
||||
if sqliteID, ok := vr.Metadata["sqlite_id"].(float64); ok {
|
||||
id := int64(sqliteID)
|
||||
// Keep the highest score for each observation
|
||||
if existing, exists := similarityScores[id]; !exists || vr.Similarity > existing {
|
||||
similarityScores[id] = vr.Similarity
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract observation IDs with project/scope filtering using shared helper
|
||||
obsIDs := sqlitevec.ExtractObservationIDs(filteredResults, project)
|
||||
|
||||
if len(obsIDs) > 0 {
|
||||
// Fetch full observations from SQLite
|
||||
observations, err = s.observationStore.GetObservationsByIDs(r.Context(), obsIDs, "date_desc", limit)
|
||||
if err == nil {
|
||||
usedVector = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to FTS if vector search not available, failed, or returned no results
|
||||
if !usedVector || len(observations) == 0 {
|
||||
if vectorSearchFailed {
|
||||
log.Info().Str("project", project).Msg("Using FTS fallback due to vector search failure")
|
||||
}
|
||||
observations, err = s.observationStore.SearchObservationsFTS(r.Context(), query, project, limit)
|
||||
if err != nil {
|
||||
// FTS might fail if query has special chars, try without
|
||||
log.Warn().Err(err).Str("query", query).Msg("FTS search failed, falling back to recent")
|
||||
observations, err = s.observationStore.GetRecentObservations(r.Context(), project, limit)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fast staleness filter - NO verification (that's too slow for interactive use)
|
||||
// Just check mtimes and exclude obviously stale observations
|
||||
var staleCount int
|
||||
freshObservations := make([]*models.Observation, 0, len(observations))
|
||||
|
||||
for _, obs := range observations {
|
||||
if len(obs.FileMtimes) > 0 && cwd != "" {
|
||||
var paths []string
|
||||
for path := range obs.FileMtimes {
|
||||
paths = append(paths, path)
|
||||
}
|
||||
currentMtimes := sdk.GetFileMtimes(paths, cwd)
|
||||
|
||||
if obs.CheckStaleness(currentMtimes) {
|
||||
// Stale - exclude but don't verify (too slow)
|
||||
// Queue for background verification instead
|
||||
staleCount++
|
||||
s.queueStaleVerification(obs.ID, cwd)
|
||||
continue
|
||||
}
|
||||
}
|
||||
freshObservations = append(freshObservations, obs)
|
||||
}
|
||||
|
||||
// Apply cross-encoder reranking if available
|
||||
var reranked bool
|
||||
if s.reranker != nil && len(freshObservations) > 0 && usedVector {
|
||||
// Build candidates from observations with their bi-encoder scores
|
||||
candidates := make([]reranking.Candidate, len(freshObservations))
|
||||
for i, obs := range freshObservations {
|
||||
// Use strings.Builder for efficient concatenation
|
||||
var content string
|
||||
if obs.Narrative.Valid && obs.Narrative.String != "" {
|
||||
var sb strings.Builder
|
||||
sb.Grow(len(obs.Title.String) + 1 + len(obs.Narrative.String))
|
||||
sb.WriteString(obs.Title.String)
|
||||
sb.WriteByte(' ')
|
||||
sb.WriteString(obs.Narrative.String)
|
||||
content = sb.String()
|
||||
} else {
|
||||
content = obs.Title.String
|
||||
}
|
||||
candidates[i] = reranking.Candidate{
|
||||
ID: strconv.FormatInt(obs.ID, 10), // Faster than fmt.Sprintf
|
||||
Content: content,
|
||||
Score: similarityScores[obs.ID],
|
||||
Metadata: map[string]any{"obs_idx": i},
|
||||
}
|
||||
}
|
||||
|
||||
// Rerank using cross-encoder - use pure mode or combined scores
|
||||
var rerankResults []reranking.RerankResult
|
||||
var rerankErr error
|
||||
if s.config.RerankingPureMode {
|
||||
rerankResults, rerankErr = s.reranker.RerankByScore(query, candidates, s.config.RerankingResults)
|
||||
} else {
|
||||
rerankResults, rerankErr = s.reranker.Rerank(query, candidates, s.config.RerankingResults)
|
||||
}
|
||||
if rerankErr != nil {
|
||||
log.Warn().Err(rerankErr).Msg("Cross-encoder reranking failed, using original order")
|
||||
} else if len(rerankResults) > 0 {
|
||||
// Update similarity scores with reranked scores
|
||||
for _, rr := range rerankResults {
|
||||
if id, err := strconv.ParseInt(rr.ID, 10, 64); err == nil {
|
||||
similarityScores[id] = rr.CombinedScore
|
||||
}
|
||||
}
|
||||
|
||||
// Reorder observations based on rerank results
|
||||
reorderedObs := make([]*models.Observation, 0, len(rerankResults))
|
||||
obsMap := make(map[int64]*models.Observation)
|
||||
for _, obs := range freshObservations {
|
||||
obsMap[obs.ID] = obs
|
||||
}
|
||||
for _, rr := range rerankResults {
|
||||
if id, err := strconv.ParseInt(rr.ID, 10, 64); err == nil {
|
||||
if obs, ok := obsMap[id]; ok {
|
||||
reorderedObs = append(reorderedObs, obs)
|
||||
}
|
||||
}
|
||||
}
|
||||
freshObservations = reorderedObs
|
||||
reranked = true
|
||||
|
||||
log.Debug().
|
||||
Int("candidates", len(candidates)).
|
||||
Int("returned", len(rerankResults)).
|
||||
Msg("Cross-encoder reranking complete")
|
||||
}
|
||||
}
|
||||
|
||||
// Cluster similar observations to remove duplicates
|
||||
clusteredObservations := clusterObservations(freshObservations, 0.4)
|
||||
duplicatesRemoved := len(freshObservations) - len(clusteredObservations)
|
||||
|
||||
// Sort by similarity score (highest first) if we have scores and didn't rerank
|
||||
if len(similarityScores) > 0 && len(clusteredObservations) > 0 && !reranked {
|
||||
sort.Slice(clusteredObservations, func(i, j int) bool {
|
||||
scoreI := similarityScores[clusteredObservations[i].ID]
|
||||
scoreJ := similarityScores[clusteredObservations[j].ID]
|
||||
return scoreI > scoreJ
|
||||
})
|
||||
}
|
||||
|
||||
// Apply max results cap if configured
|
||||
if maxResults > 0 && len(clusteredObservations) > maxResults {
|
||||
clusteredObservations = clusteredObservations[:maxResults]
|
||||
}
|
||||
|
||||
// Record retrieval stats with staleness metrics
|
||||
s.recordRetrievalStatsExtended(project, int64(len(clusteredObservations)), 0, 0,
|
||||
int64(staleCount), int64(len(freshObservations)), int64(duplicatesRemoved), true)
|
||||
|
||||
// Increment retrieval counts for scoring (async, non-blocking)
|
||||
if len(clusteredObservations) > 0 {
|
||||
ids := make([]int64, len(clusteredObservations))
|
||||
for i, obs := range clusteredObservations {
|
||||
ids[i] = obs.ID
|
||||
}
|
||||
s.incrementRetrievalCounts(ids)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("project", project).
|
||||
Str("query", query).
|
||||
Str("intent", detectedIntent).
|
||||
Int("expansions", len(expandedQueries)).
|
||||
Int("found", len(clusteredObservations)).
|
||||
Int("stale_excluded", staleCount).
|
||||
Float64("threshold", threshold).
|
||||
Msg("Prompt-based observation search")
|
||||
|
||||
// Build response with similarity scores
|
||||
obsWithScores := make([]map[string]any, len(clusteredObservations))
|
||||
for i, obs := range clusteredObservations {
|
||||
obsMap := obs.ToMap()
|
||||
if score, ok := similarityScores[obs.ID]; ok {
|
||||
obsMap["similarity"] = score
|
||||
}
|
||||
obsWithScores[i] = obsMap
|
||||
}
|
||||
|
||||
// Build expansion info for response
|
||||
expansionInfo := make([]map[string]any, len(expandedQueries))
|
||||
for i, eq := range expandedQueries {
|
||||
expansionInfo[i] = map[string]any{
|
||||
"query": eq.Query,
|
||||
"weight": eq.Weight,
|
||||
"source": eq.Source,
|
||||
}
|
||||
}
|
||||
|
||||
// Track this search for analytics
|
||||
s.trackSearchQuery(query, project, "observations", len(clusteredObservations), usedVector)
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"project": project,
|
||||
"query": query,
|
||||
"intent": detectedIntent,
|
||||
"expansions": expansionInfo,
|
||||
"observations": obsWithScores,
|
||||
"threshold": threshold,
|
||||
"max_results": maxResults,
|
||||
})
|
||||
}
|
||||
|
||||
// handleFileContext returns observations relevant to specific files being worked on.
|
||||
// Uses vector similarity search to find observations that mention or relate to the given files.
|
||||
func (s *Service) handleFileContext(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
if project == "" {
|
||||
http.Error(w, "project required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
filesParam := r.URL.Query().Get("files")
|
||||
if filesParam == "" {
|
||||
http.Error(w, "files required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse comma-separated file paths
|
||||
files := strings.Split(filesParam, ",")
|
||||
if len(files) == 0 {
|
||||
http.Error(w, "at least one file required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Limit to reasonable number of files
|
||||
maxFiles := 20
|
||||
if len(files) > maxFiles {
|
||||
files = files[:maxFiles]
|
||||
}
|
||||
|
||||
// Get limit parameter (default 10 per file)
|
||||
limitStr := r.URL.Query().Get("limit")
|
||||
limit := 10
|
||||
if limitStr != "" {
|
||||
if parsed, err := strconv.Atoi(limitStr); err == nil && parsed > 0 && parsed <= 50 {
|
||||
limit = parsed
|
||||
}
|
||||
}
|
||||
|
||||
// Search for observations related to each file in parallel
|
||||
ctx := r.Context()
|
||||
|
||||
// Check if vector search is available
|
||||
if s.vectorClient == nil || !s.vectorClient.IsConnected() {
|
||||
writeJSON(w, map[string]any{
|
||||
"files": files,
|
||||
"results": map[string]any{},
|
||||
"count": 0,
|
||||
"error": "vector search not available",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Prepare for parallel execution
|
||||
type fileResult struct {
|
||||
file string
|
||||
results []map[string]any
|
||||
obsIDs []int64 // Track observation IDs for deduplication
|
||||
}
|
||||
|
||||
resultsChan := make(chan fileResult, len(files))
|
||||
sem := make(chan struct{}, 5) // Limit concurrency to 5 parallel searches
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for _, file := range files {
|
||||
file = strings.TrimSpace(file)
|
||||
if file == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(file string) {
|
||||
defer wg.Done()
|
||||
sem <- struct{}{} // Acquire semaphore
|
||||
defer func() { <-sem }() // Release semaphore
|
||||
|
||||
// Build search query from file path
|
||||
query := buildFileQuery(file)
|
||||
|
||||
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "")
|
||||
vectorResults, vecErr := s.vectorClient.Query(ctx, query, limit*2, where)
|
||||
if vecErr != nil {
|
||||
log.Warn().Err(vecErr).Str("file", file).Msg("Vector search failed for file context")
|
||||
return
|
||||
}
|
||||
|
||||
// Extract observation IDs from vector results
|
||||
obsIDs := sqlitevec.ExtractObservationIDs(vectorResults, project)
|
||||
if len(obsIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Fetch observations
|
||||
observations, err := s.observationStore.GetObservationsByIDs(ctx, obsIDs, "score_desc", limit*2)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("file", file).Msg("Failed to fetch observations for file context")
|
||||
return
|
||||
}
|
||||
|
||||
// Pre-build score map from vector results (O(n) instead of O(n²))
|
||||
scoreMap := make(map[int64]float64, len(vectorResults))
|
||||
var avgScore float64
|
||||
for _, vr := range vectorResults {
|
||||
avgScore += vr.Similarity
|
||||
// Parse observation ID from vector result ID (format: obs_{id}_{field})
|
||||
// Use index-based parsing to avoid slice allocation from strings.Split
|
||||
if len(vr.ID) > 4 && vr.ID[:4] == "obs_" {
|
||||
rest := vr.ID[4:] // Skip "obs_"
|
||||
underscoreIdx := strings.IndexByte(rest, '_')
|
||||
var idStr string
|
||||
if underscoreIdx >= 0 {
|
||||
idStr = rest[:underscoreIdx]
|
||||
} else {
|
||||
idStr = rest
|
||||
}
|
||||
if id, parseErr := strconv.ParseInt(idStr, 10, 64); parseErr == nil {
|
||||
// Keep highest score for each observation
|
||||
if existing, exists := scoreMap[id]; !exists || vr.Similarity > existing {
|
||||
scoreMap[id] = vr.Similarity
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(vectorResults) > 0 {
|
||||
avgScore /= float64(len(vectorResults))
|
||||
}
|
||||
|
||||
fileResults := make([]map[string]any, 0, limit)
|
||||
var usedIDs []int64
|
||||
for _, obs := range observations {
|
||||
// Check project scope
|
||||
if obs.Scope == "project" && obs.Project != project {
|
||||
continue
|
||||
}
|
||||
|
||||
// O(1) score lookup instead of O(n) nested loop
|
||||
score, found := scoreMap[obs.ID]
|
||||
if !found {
|
||||
// Use average score as fallback
|
||||
score = avgScore
|
||||
}
|
||||
|
||||
// Only include if score is above threshold
|
||||
if score < 0.3 {
|
||||
continue
|
||||
}
|
||||
|
||||
fileResults = append(fileResults, map[string]any{
|
||||
"id": obs.ID,
|
||||
"title": obs.Title.String,
|
||||
"type": obs.Type,
|
||||
"narrative": obs.Narrative.String,
|
||||
"facts": obs.Facts,
|
||||
"score": score,
|
||||
})
|
||||
usedIDs = append(usedIDs, obs.ID)
|
||||
|
||||
if len(fileResults) >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(fileResults) > 0 {
|
||||
resultsChan <- fileResult{file: file, results: fileResults, obsIDs: usedIDs}
|
||||
}
|
||||
}(file)
|
||||
}
|
||||
|
||||
// Close channel when all goroutines complete
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(resultsChan)
|
||||
}()
|
||||
|
||||
// Collect results and deduplicate
|
||||
allResults := make(map[string]any)
|
||||
seenObservationIDs := make(map[int64]bool)
|
||||
|
||||
for res := range resultsChan {
|
||||
// Filter out duplicates that were already seen in other files
|
||||
dedupedResults := make([]map[string]any, 0, len(res.results))
|
||||
for i, r := range res.results {
|
||||
obsID := res.obsIDs[i]
|
||||
if !seenObservationIDs[obsID] {
|
||||
seenObservationIDs[obsID] = true
|
||||
dedupedResults = append(dedupedResults, r)
|
||||
}
|
||||
}
|
||||
if len(dedupedResults) > 0 {
|
||||
allResults[res.file] = dedupedResults
|
||||
}
|
||||
}
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"files": files,
|
||||
"results": allResults,
|
||||
"count": len(allResults),
|
||||
})
|
||||
}
|
||||
|
||||
// buildFileQuery extracts meaningful search terms from a file path.
|
||||
func buildFileQuery(filePath string) string {
|
||||
// Remove common prefixes and extensions
|
||||
path := strings.TrimPrefix(filePath, "/")
|
||||
|
||||
// Extract the filename and directory
|
||||
parts := strings.Split(path, "/")
|
||||
meaningful := make([]string, 0, len(parts))
|
||||
|
||||
for _, part := range parts {
|
||||
// Skip common directory names that aren't meaningful
|
||||
switch strings.ToLower(part) {
|
||||
case "src", "lib", "internal", "pkg", "cmd", "api", "app", "test", "tests", "spec", "specs":
|
||||
continue
|
||||
default:
|
||||
// Remove file extension
|
||||
if idx := strings.LastIndex(part, "."); idx > 0 {
|
||||
part = part[:idx]
|
||||
}
|
||||
// Convert camelCase/PascalCase to spaces
|
||||
part = splitCamelCase(part)
|
||||
// Convert snake_case to spaces
|
||||
part = strings.ReplaceAll(part, "_", " ")
|
||||
// Convert kebab-case to spaces
|
||||
part = strings.ReplaceAll(part, "-", " ")
|
||||
meaningful = append(meaningful, part)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(meaningful, " ")
|
||||
}
|
||||
|
||||
// splitCamelCase splits camelCase or PascalCase into separate words.
|
||||
func splitCamelCase(s string) string {
|
||||
var result strings.Builder
|
||||
for i, r := range s {
|
||||
if i > 0 && r >= 'A' && r <= 'Z' {
|
||||
result.WriteRune(' ')
|
||||
}
|
||||
result.WriteRune(r)
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// handleContextInject returns context for injection at session start.
|
||||
// IMPORTANT: This is on the critical startup path - must be fast!
|
||||
// No synchronous verification - just filter by staleness and return.
|
||||
func (s *Service) handleContextInject(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
if project == "" {
|
||||
http.Error(w, "project required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate project name to prevent path traversal
|
||||
if err := ValidateProjectName(project); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
cwd := r.URL.Query().Get("cwd")
|
||||
if cwd == "" {
|
||||
cwd = "/"
|
||||
}
|
||||
|
||||
// Limit observations for fast startup (configurable, default 100)
|
||||
limit := s.config.ContextObservations
|
||||
if limit <= 0 {
|
||||
limit = DefaultContextLimit
|
||||
}
|
||||
|
||||
// Full count determines how many observations get full detail (configurable, default 25)
|
||||
fullCount := s.config.ContextFullCount
|
||||
if fullCount <= 0 {
|
||||
fullCount = 25
|
||||
}
|
||||
|
||||
// Get recent observations
|
||||
observations, err := s.observationStore.GetRecentObservations(r.Context(), project, limit)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Fast staleness filter - NO verification (that's too slow for startup)
|
||||
var staleCount int
|
||||
freshObservations := make([]*models.Observation, 0, len(observations))
|
||||
|
||||
for _, obs := range observations {
|
||||
if len(obs.FileMtimes) > 0 {
|
||||
var paths []string
|
||||
for path := range obs.FileMtimes {
|
||||
paths = append(paths, path)
|
||||
}
|
||||
currentMtimes := sdk.GetFileMtimes(paths, cwd)
|
||||
|
||||
if obs.CheckStaleness(currentMtimes) {
|
||||
// Stale - exclude but don't verify (too slow)
|
||||
// Queue for background verification instead
|
||||
staleCount++
|
||||
s.queueStaleVerification(obs.ID, cwd)
|
||||
continue
|
||||
}
|
||||
}
|
||||
freshObservations = append(freshObservations, obs)
|
||||
}
|
||||
|
||||
// Cluster similar observations to remove duplicates
|
||||
clusteredObservations := clusterObservations(freshObservations, 0.4)
|
||||
duplicatesRemoved := len(freshObservations) - len(clusteredObservations)
|
||||
|
||||
// Record retrieval stats with staleness metrics
|
||||
s.recordRetrievalStatsExtended(project, int64(len(clusteredObservations)), 0, 0,
|
||||
int64(staleCount), int64(len(freshObservations)), int64(duplicatesRemoved), false)
|
||||
|
||||
// Increment retrieval counts for scoring (async, non-blocking)
|
||||
if len(clusteredObservations) > 0 {
|
||||
ids := make([]int64, len(clusteredObservations))
|
||||
for i, obs := range clusteredObservations {
|
||||
ids[i] = obs.ID
|
||||
}
|
||||
s.incrementRetrievalCounts(ids)
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("project", project).
|
||||
Int("total", len(observations)).
|
||||
Int("fresh", len(freshObservations)).
|
||||
Int("clustered", len(clusteredObservations)).
|
||||
Int("duplicates", duplicatesRemoved).
|
||||
Int("stale_excluded", staleCount).
|
||||
Msg("Context injection with clustering")
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"project": project,
|
||||
"observations": clusteredObservations,
|
||||
"full_count": fullCount,
|
||||
"stale_excluded": staleCount,
|
||||
"duplicates_removed": duplicatesRemoved,
|
||||
})
|
||||
}
|
||||
|
||||
// handleContextCount returns the count of observations for a project.
|
||||
func (s *Service) handleContextCount(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
if project == "" {
|
||||
http.Error(w, "project required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
count, err := s.getCachedObservationCount(r.Context(), project)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"project": project,
|
||||
"count": count,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,595 @@
|
||||
// Package worker provides data retrieval HTTP handlers.
|
||||
package worker
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/embedding"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/vector/sqlitevec"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// handleGetObservations returns recent observations.
|
||||
// Supports optional query parameter for semantic search via sqlite-vec.
|
||||
// Supports pagination via limit and offset query parameters.
|
||||
func (s *Service) handleGetObservations(w http.ResponseWriter, r *http.Request) {
|
||||
pagination := gorm.ParsePaginationParams(r, DefaultObservationsLimit)
|
||||
project := r.URL.Query().Get("project")
|
||||
query := r.URL.Query().Get("query")
|
||||
|
||||
// Validate project name to prevent path traversal
|
||||
if err := ValidateProjectName(project); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var observations []*models.Observation
|
||||
var total int64
|
||||
var err error
|
||||
var usedVector bool
|
||||
|
||||
// Use vector search if query is provided and vector client is available
|
||||
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
|
||||
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "")
|
||||
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, pagination.Limit*2, where)
|
||||
if vecErr == nil && len(vectorResults) > 0 {
|
||||
obsIDs := sqlitevec.ExtractObservationIDs(vectorResults, project)
|
||||
if len(obsIDs) > 0 {
|
||||
observations, err = s.observationStore.GetObservationsByIDs(r.Context(), obsIDs, "date_desc", pagination.Limit)
|
||||
if err == nil {
|
||||
usedVector = true
|
||||
total = int64(len(observations)) // Vector search doesn't have total, use returned count
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to SQLite if vector search not used
|
||||
if !usedVector {
|
||||
if project != "" {
|
||||
// Strict project filtering for dashboard - only observations from this project
|
||||
observations, total, err = s.observationStore.GetObservationsByProjectStrictPaginated(r.Context(), project, pagination.Limit, pagination.Offset)
|
||||
} else {
|
||||
// All projects
|
||||
observations, total, err = s.observationStore.GetAllRecentObservationsPaginated(r.Context(), pagination.Limit, pagination.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure we return empty array, not null
|
||||
if observations == nil {
|
||||
observations = []*models.Observation{}
|
||||
}
|
||||
|
||||
// Track search if query was provided
|
||||
if query != "" {
|
||||
s.trackSearchQuery(query, project, "observations", len(observations), usedVector)
|
||||
}
|
||||
|
||||
// Return paginated response
|
||||
writeJSON(w, map[string]any{
|
||||
"observations": observations,
|
||||
"total": total,
|
||||
"limit": pagination.Limit,
|
||||
"offset": pagination.Offset,
|
||||
"hasMore": int64(pagination.Offset)+int64(len(observations)) < total,
|
||||
})
|
||||
}
|
||||
|
||||
// handleGetSummaries returns recent summaries.
|
||||
// Supports optional query parameter for semantic search via sqlite-vec.
|
||||
func (s *Service) handleGetSummaries(w http.ResponseWriter, r *http.Request) {
|
||||
limit := gorm.ParseLimitParam(r, DefaultSummariesLimit)
|
||||
project := r.URL.Query().Get("project")
|
||||
query := r.URL.Query().Get("query")
|
||||
|
||||
// Validate project name to prevent path traversal
|
||||
if err := ValidateProjectName(project); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var summaries []*models.SessionSummary
|
||||
var err error
|
||||
var usedVector bool
|
||||
|
||||
// Use vector search if query is provided and vector client is available
|
||||
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
|
||||
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeSessionSummary, "")
|
||||
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where)
|
||||
if vecErr == nil && len(vectorResults) > 0 {
|
||||
summaryIDs := sqlitevec.ExtractSummaryIDs(vectorResults, project)
|
||||
if len(summaryIDs) > 0 {
|
||||
summaries, err = s.summaryStore.GetSummariesByIDs(r.Context(), summaryIDs, "date_desc", limit)
|
||||
if err == nil {
|
||||
usedVector = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to SQLite if vector search not used
|
||||
if !usedVector {
|
||||
if project != "" {
|
||||
summaries, err = s.summaryStore.GetRecentSummaries(r.Context(), project, limit)
|
||||
} else {
|
||||
summaries, err = s.summaryStore.GetAllRecentSummaries(r.Context(), limit)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure we return empty array, not null
|
||||
if summaries == nil {
|
||||
summaries = []*models.SessionSummary{}
|
||||
}
|
||||
writeJSON(w, summaries)
|
||||
}
|
||||
|
||||
// handleGetPrompts returns recent user prompts.
|
||||
// Supports optional query parameter for semantic search via sqlite-vec.
|
||||
func (s *Service) handleGetPrompts(w http.ResponseWriter, r *http.Request) {
|
||||
limit := gorm.ParseLimitParam(r, DefaultPromptsLimit)
|
||||
project := r.URL.Query().Get("project")
|
||||
query := r.URL.Query().Get("query")
|
||||
|
||||
// Validate project name to prevent path traversal
|
||||
if err := ValidateProjectName(project); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var prompts []*models.UserPromptWithSession
|
||||
var err error
|
||||
var usedVector bool
|
||||
|
||||
// Use vector search if query is provided and vector client is available
|
||||
if query != "" && s.vectorClient != nil && s.vectorClient.IsConnected() {
|
||||
where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeUserPrompt, "")
|
||||
vectorResults, vecErr := s.vectorClient.Query(r.Context(), query, limit*2, where)
|
||||
if vecErr == nil && len(vectorResults) > 0 {
|
||||
promptIDs := sqlitevec.ExtractPromptIDs(vectorResults, project)
|
||||
if len(promptIDs) > 0 {
|
||||
prompts, err = s.promptStore.GetPromptsByIDs(r.Context(), promptIDs, "date_desc", limit)
|
||||
if err == nil {
|
||||
usedVector = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to SQLite if vector search not used
|
||||
if !usedVector {
|
||||
if project != "" {
|
||||
prompts, err = s.promptStore.GetRecentUserPromptsByProject(r.Context(), project, limit)
|
||||
} else {
|
||||
prompts, err = s.promptStore.GetAllRecentUserPrompts(r.Context(), limit)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure we return empty array, not null
|
||||
if prompts == nil {
|
||||
prompts = []*models.UserPromptWithSession{}
|
||||
}
|
||||
writeJSON(w, prompts)
|
||||
}
|
||||
|
||||
// handleGetProjects returns all projects.
|
||||
// Response is cacheable for 5 minutes since project list changes infrequently.
|
||||
func (s *Service) handleGetProjects(w http.ResponseWriter, r *http.Request) {
|
||||
projects, err := s.sessionStore.GetAllProjects(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Cache for 5 minutes - project list changes infrequently
|
||||
w.Header().Set("Cache-Control", "public, max-age=300")
|
||||
writeJSON(w, projects)
|
||||
}
|
||||
|
||||
// handleGetTypes returns the canonical list of observation and concept types.
|
||||
// This provides a single source of truth for both backend and frontend.
|
||||
// Response is cacheable as these values never change at runtime.
|
||||
func (s *Service) handleGetTypes(w http.ResponseWriter, r *http.Request) {
|
||||
// Cache for 24 hours - these values are compile-time constants
|
||||
w.Header().Set("Cache-Control", "public, max-age=86400")
|
||||
writeJSON(w, map[string]any{
|
||||
"observation_types": ObservationTypes,
|
||||
"concept_types": ConceptTypes,
|
||||
})
|
||||
}
|
||||
|
||||
// handleGetModels returns available embedding models.
|
||||
// Response is cacheable as model list doesn't change without restart.
|
||||
func (s *Service) handleGetModels(w http.ResponseWriter, _ *http.Request) {
|
||||
// Cache for 1 hour - model list is static during runtime
|
||||
w.Header().Set("Cache-Control", "public, max-age=3600")
|
||||
|
||||
models := embedding.ListModels()
|
||||
defaultModel := embedding.GetDefaultModel()
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"models": models,
|
||||
"default": defaultModel,
|
||||
"current": s.embedSvc.Version(),
|
||||
})
|
||||
}
|
||||
|
||||
// handleGetStats returns worker statistics.
|
||||
func (s *Service) handleGetStats(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
|
||||
// Validate project name to prevent path traversal
|
||||
if err := ValidateProjectName(project); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
retrievalStats := s.GetRetrievalStats(project)
|
||||
sessionsToday, _ := s.sessionStore.GetSessionsToday(r.Context())
|
||||
|
||||
response := map[string]any{
|
||||
"uptime": time.Since(s.startTime).String(),
|
||||
"uptimeSeconds": time.Since(s.startTime).Seconds(),
|
||||
"activeSessions": s.sessionManager.GetActiveSessionCount(),
|
||||
"queueDepth": s.sessionManager.GetTotalQueueDepth(),
|
||||
"isProcessing": s.sessionManager.IsAnySessionProcessing(),
|
||||
"connectedClients": s.sseBroadcaster.ClientCount(),
|
||||
"sessionsToday": sessionsToday,
|
||||
"retrieval": retrievalStats,
|
||||
"ready": s.ready.Load(),
|
||||
}
|
||||
|
||||
// Add memory stats
|
||||
var memStats runtime.MemStats
|
||||
runtime.ReadMemStats(&memStats)
|
||||
response["memory"] = map[string]any{
|
||||
"alloc_mb": float64(memStats.Alloc) / 1024 / 1024,
|
||||
"total_alloc_mb": float64(memStats.TotalAlloc) / 1024 / 1024,
|
||||
"sys_mb": float64(memStats.Sys) / 1024 / 1024,
|
||||
"heap_alloc_mb": float64(memStats.HeapAlloc) / 1024 / 1024,
|
||||
"heap_inuse_mb": float64(memStats.HeapInuse) / 1024 / 1024,
|
||||
"heap_objects": memStats.HeapObjects,
|
||||
"goroutines": runtime.NumGoroutine(),
|
||||
"gc_cycles": memStats.NumGC,
|
||||
"gc_pause_total_ms": float64(memStats.PauseTotalNs) / 1e6,
|
||||
}
|
||||
|
||||
// Add database health if available
|
||||
if s.store != nil {
|
||||
dbHealth := s.store.HealthCheck(r.Context())
|
||||
response["database"] = map[string]any{
|
||||
"status": dbHealth.Status,
|
||||
"query_latency_ms": float64(dbHealth.QueryLatency) / 1e6,
|
||||
"pool": dbHealth.PoolStats,
|
||||
"warning": dbHealth.Warning,
|
||||
}
|
||||
}
|
||||
|
||||
// Add embedding model info
|
||||
if s.embedSvc != nil {
|
||||
response["embeddingModel"] = map[string]any{
|
||||
"name": s.embedSvc.Name(),
|
||||
"version": s.embedSvc.Version(),
|
||||
"dimensions": s.embedSvc.Dimensions(),
|
||||
}
|
||||
}
|
||||
|
||||
// Add vector cache stats
|
||||
if s.vectorClient != nil {
|
||||
if count, err := s.vectorClient.Count(r.Context()); err == nil {
|
||||
response["vectorCount"] = count
|
||||
}
|
||||
cacheSize, cacheMax := s.vectorClient.CacheStats()
|
||||
response["vectorCache"] = map[string]any{
|
||||
"size": cacheSize,
|
||||
"max_size": cacheMax,
|
||||
}
|
||||
}
|
||||
|
||||
// Include project-specific observation count if project is specified
|
||||
if project != "" {
|
||||
count, err := s.getCachedObservationCount(r.Context(), project)
|
||||
if err == nil {
|
||||
response["projectObservations"] = count
|
||||
response["project"] = project
|
||||
}
|
||||
}
|
||||
|
||||
// Add rate limiter stats
|
||||
if s.rateLimiter != nil {
|
||||
response["rateLimiter"] = s.rateLimiter.Stats()
|
||||
}
|
||||
|
||||
// Add circuit breaker metrics
|
||||
if s.processor != nil {
|
||||
response["circuitBreaker"] = s.processor.CircuitBreakerMetrics()
|
||||
}
|
||||
|
||||
writeJSON(w, response)
|
||||
}
|
||||
|
||||
// handleGetRetrievalStats returns detailed retrieval statistics.
|
||||
func (s *Service) handleGetRetrievalStats(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
stats := s.GetRetrievalStats(project)
|
||||
writeJSON(w, stats)
|
||||
}
|
||||
|
||||
// handleGetRecentQueries returns recent search queries for analytics.
|
||||
func (s *Service) handleGetRecentQueries(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
limit := gorm.ParseLimitParam(r, 20)
|
||||
|
||||
queries := s.getRecentSearchQueries(project, limit)
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"queries": queries,
|
||||
"count": len(queries),
|
||||
"project": project,
|
||||
})
|
||||
}
|
||||
|
||||
// handleGetSearchAnalytics returns comprehensive search analytics and statistics.
|
||||
func (s *Service) handleGetSearchAnalytics(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
|
||||
// Get all recent queries for analysis
|
||||
queries := s.getRecentSearchQueries(project, maxRecentQueries)
|
||||
|
||||
// Calculate analytics
|
||||
totalQueries := len(queries)
|
||||
vectorSearches := 0
|
||||
totalResults := 0
|
||||
zeroResultQueries := 0
|
||||
queryTypes := make(map[string]int)
|
||||
topKeywords := make(map[string]int)
|
||||
|
||||
for _, q := range queries {
|
||||
if q.UsedVector {
|
||||
vectorSearches++
|
||||
}
|
||||
totalResults += q.Results
|
||||
if q.Results == 0 {
|
||||
zeroResultQueries++
|
||||
}
|
||||
queryTypes[q.Type]++
|
||||
|
||||
// Extract keywords (simple word tokenization using iterator)
|
||||
for word := range strings.FieldsSeq(strings.ToLower(q.Query)) {
|
||||
if len(word) > 3 { // Skip short words
|
||||
topKeywords[word]++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort keywords by frequency
|
||||
type keywordCount struct {
|
||||
Keyword string `json:"keyword"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
sortedKeywords := make([]keywordCount, 0, len(topKeywords))
|
||||
for kw, count := range topKeywords {
|
||||
sortedKeywords = append(sortedKeywords, keywordCount{Keyword: kw, Count: count})
|
||||
}
|
||||
sort.Slice(sortedKeywords, func(i, j int) bool {
|
||||
return sortedKeywords[i].Count > sortedKeywords[j].Count
|
||||
})
|
||||
if len(sortedKeywords) > 10 {
|
||||
sortedKeywords = sortedKeywords[:10]
|
||||
}
|
||||
|
||||
// Calculate averages
|
||||
avgResults := float64(0)
|
||||
vectorSearchRate := float64(0)
|
||||
zeroResultRate := float64(0)
|
||||
if totalQueries > 0 {
|
||||
avgResults = float64(totalResults) / float64(totalQueries)
|
||||
vectorSearchRate = float64(vectorSearches) / float64(totalQueries) * 100
|
||||
zeroResultRate = float64(zeroResultQueries) / float64(totalQueries) * 100
|
||||
}
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"total_queries": totalQueries,
|
||||
"vector_search_rate": vectorSearchRate,
|
||||
"avg_results": avgResults,
|
||||
"zero_result_rate": zeroResultRate,
|
||||
"query_types": queryTypes,
|
||||
"top_keywords": sortedKeywords,
|
||||
"project": project,
|
||||
})
|
||||
}
|
||||
|
||||
// handleVectorHealth returns comprehensive health information about the vector database.
|
||||
func (s *Service) handleVectorHealth(w http.ResponseWriter, r *http.Request) {
|
||||
if s.vectorClient == nil {
|
||||
http.Error(w, "vector client not initialized", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := s.vectorClient.GetHealthStats(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Add additional computed metrics
|
||||
healthScore := 100.0
|
||||
var warnings []string
|
||||
|
||||
// Penalize for stale vectors
|
||||
if stats.TotalVectors > 0 {
|
||||
staleRatio := float64(stats.StaleVectors) / float64(stats.TotalVectors)
|
||||
if staleRatio > 0 {
|
||||
healthScore -= staleRatio * 50 // Up to 50 points off for stale vectors
|
||||
warnings = append(warnings, formatWarning("%.1f%% vectors need rebuild", staleRatio*100))
|
||||
}
|
||||
}
|
||||
|
||||
// Check cache effectiveness
|
||||
cacheHitRate := stats.EmbeddingCache.HitRate()
|
||||
if cacheHitRate < 20 && (stats.EmbeddingCache.EmbeddingHits+stats.EmbeddingCache.EmbeddingMisses) > 100 {
|
||||
healthScore -= 10
|
||||
warnings = append(warnings, formatWarning("Low cache hit rate: %.1f%%", cacheHitRate))
|
||||
}
|
||||
|
||||
// Penalize if rebuild is needed
|
||||
if stats.NeedsRebuild {
|
||||
healthScore -= 20
|
||||
warnings = append(warnings, "Vector rebuild recommended: "+stats.RebuildReason)
|
||||
}
|
||||
|
||||
if healthScore < 0 {
|
||||
healthScore = 0
|
||||
}
|
||||
|
||||
status := "healthy"
|
||||
if healthScore < 50 {
|
||||
status = "unhealthy"
|
||||
} else if healthScore < 80 {
|
||||
status = "degraded"
|
||||
}
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"status": status,
|
||||
"health_score": healthScore,
|
||||
"warnings": warnings,
|
||||
"stats": stats,
|
||||
"cache_hit_rate": cacheHitRate,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateObservationRequest is the request body for updating an observation.
|
||||
type UpdateObservationRequest struct {
|
||||
Title *string `json:"title,omitempty"`
|
||||
Subtitle *string `json:"subtitle,omitempty"`
|
||||
Narrative *string `json:"narrative,omitempty"`
|
||||
Scope *string `json:"scope,omitempty"`
|
||||
Facts []string `json:"facts,omitempty"`
|
||||
Concepts []string `json:"concepts,omitempty"`
|
||||
FilesRead []string `json:"files_read,omitempty"`
|
||||
FilesModified []string `json:"files_modified,omitempty"`
|
||||
}
|
||||
|
||||
// handleUpdateObservation updates an existing observation.
|
||||
// PUT /api/observations/{id}
|
||||
func (s *Service) handleUpdateObservation(w http.ResponseWriter, r *http.Request) {
|
||||
// Parse observation ID from URL
|
||||
id, ok := parseIDParam(w, r.PathValue("id"), "observation")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request body
|
||||
var req UpdateObservationRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Build update struct - only include fields that were provided
|
||||
update := &gorm.ObservationUpdate{}
|
||||
|
||||
if req.Title != nil {
|
||||
update.Title = req.Title
|
||||
}
|
||||
if req.Subtitle != nil {
|
||||
update.Subtitle = req.Subtitle
|
||||
}
|
||||
if req.Narrative != nil {
|
||||
update.Narrative = req.Narrative
|
||||
}
|
||||
if req.Facts != nil {
|
||||
update.Facts = &req.Facts
|
||||
}
|
||||
if req.Concepts != nil {
|
||||
update.Concepts = &req.Concepts
|
||||
}
|
||||
if req.FilesRead != nil {
|
||||
update.FilesRead = &req.FilesRead
|
||||
}
|
||||
if req.FilesModified != nil {
|
||||
update.FilesModified = &req.FilesModified
|
||||
}
|
||||
if req.Scope != nil {
|
||||
// Validate scope
|
||||
if *req.Scope != "project" && *req.Scope != "global" {
|
||||
http.Error(w, "scope must be 'project' or 'global'", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
update.Scope = req.Scope
|
||||
}
|
||||
|
||||
// Update the observation
|
||||
updatedObs, err := s.observationStore.UpdateObservation(r.Context(), id, update)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
http.Error(w, err.Error(), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
http.Error(w, "failed to update observation: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Trigger vector resync for the updated observation
|
||||
if s.vectorSync != nil {
|
||||
s.asyncVectorSync(func() {
|
||||
if err := s.vectorSync.SyncObservation(s.ctx, updatedObs); err != nil {
|
||||
log.Warn().Err(err).Int64("id", id).Msg("Failed to resync observation vectors after update")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Broadcast update event
|
||||
s.sseBroadcaster.Broadcast(map[string]any{
|
||||
"type": "observation_updated",
|
||||
"id": id,
|
||||
})
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"observation": updatedObs,
|
||||
"message": "observation updated successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// handleGetObservationByID returns a single observation by ID.
|
||||
// GET /api/observations/{id}
|
||||
func (s *Service) handleGetObservationByID(w http.ResponseWriter, r *http.Request) {
|
||||
id, ok := parseIDParam(w, r.PathValue("id"), "observation")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
obs, err := s.observationStore.GetObservationByID(r.Context(), id)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to get observation: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if obs == nil {
|
||||
http.Error(w, "observation not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, obs)
|
||||
}
|
||||
@@ -0,0 +1,680 @@
|
||||
// Package worker provides import, export, and archive HTTP handlers.
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/db/gorm"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/similarity"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// BulkImportRequest is the request body for bulk observation import.
|
||||
type BulkImportRequest struct {
|
||||
Project string `json:"project"`
|
||||
Observations []BulkObservationInput `json:"observations"`
|
||||
}
|
||||
|
||||
// BulkObservationInput represents a single observation in bulk import.
|
||||
type BulkObservationInput struct {
|
||||
Type string `json:"type"`
|
||||
Title string `json:"title"`
|
||||
Subtitle string `json:"subtitle,omitempty"`
|
||||
Narrative string `json:"narrative,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
Facts []string `json:"facts,omitempty"`
|
||||
Concepts []string `json:"concepts,omitempty"`
|
||||
FilesRead []string `json:"files_read,omitempty"`
|
||||
FilesModified []string `json:"files_modified,omitempty"`
|
||||
}
|
||||
|
||||
// BulkImportResponse contains the result of a bulk import operation.
|
||||
type BulkImportResponse struct {
|
||||
Errors []string `json:"errors,omitempty"`
|
||||
Imported int `json:"imported"`
|
||||
Failed int `json:"failed"`
|
||||
SkippedDuplicates int `json:"skipped_duplicates,omitempty"`
|
||||
}
|
||||
|
||||
// handleBulkImport handles bulk import of observations.
|
||||
// This is useful for migrating data or importing observations from external sources.
|
||||
func (s *Service) handleBulkImport(w http.ResponseWriter, r *http.Request) {
|
||||
// Rate limit bulk operations to prevent DoS
|
||||
if s.bulkOpLimiter != nil && !s.bulkOpLimiter.CanExecute() {
|
||||
remaining := s.bulkOpLimiter.CooldownRemaining()
|
||||
http.Error(w, fmt.Sprintf("bulk import rate limited, retry in %d seconds", remaining), http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
var req BulkImportRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Project == "" {
|
||||
http.Error(w, "project is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate project name to prevent path traversal
|
||||
if err := ValidateProjectName(req.Project); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Observations) == 0 {
|
||||
http.Error(w, "at least one observation is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Limit batch size to prevent overwhelming the system
|
||||
maxBatchSize := 100
|
||||
if len(req.Observations) > maxBatchSize {
|
||||
http.Error(w, fmt.Sprintf("batch size exceeds maximum of %d", maxBatchSize), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Create a synthetic session for bulk import
|
||||
sessionID, err := s.sessionStore.CreateSDKSession(r.Context(), fmt.Sprintf("bulk-import-%d", time.Now().UnixMilli()), req.Project, "bulk import")
|
||||
if err != nil {
|
||||
http.Error(w, "failed to create import session: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var imported, failed, skippedDupes int
|
||||
var errors []string
|
||||
|
||||
// Track imported observations for deduplication within the batch
|
||||
importedObs := make([]*models.Observation, 0, len(req.Observations))
|
||||
|
||||
// Deduplication threshold - observations more similar than this are considered duplicates
|
||||
const dedupThreshold = 0.7
|
||||
|
||||
for i, obsInput := range req.Observations {
|
||||
// Validate observation type using O(1) map lookup
|
||||
if !IsValidObservationType(obsInput.Type) {
|
||||
failed++
|
||||
errors = append(errors, fmt.Sprintf("observation %d: invalid type '%s'", i, obsInput.Type))
|
||||
continue
|
||||
}
|
||||
|
||||
// Build parsed observation
|
||||
parsedObs := &models.ParsedObservation{
|
||||
Type: models.ObservationType(obsInput.Type),
|
||||
Title: obsInput.Title,
|
||||
Subtitle: obsInput.Subtitle,
|
||||
Facts: obsInput.Facts,
|
||||
Narrative: obsInput.Narrative,
|
||||
Concepts: obsInput.Concepts,
|
||||
FilesRead: obsInput.FilesRead,
|
||||
FilesModified: obsInput.FilesModified,
|
||||
Scope: models.ObservationScope(obsInput.Scope),
|
||||
}
|
||||
|
||||
// Convert to temporary observation for similarity check
|
||||
tempObs := &models.Observation{
|
||||
Title: sql.NullString{String: parsedObs.Title, Valid: parsedObs.Title != ""},
|
||||
Subtitle: sql.NullString{String: parsedObs.Subtitle, Valid: parsedObs.Subtitle != ""},
|
||||
Narrative: sql.NullString{String: parsedObs.Narrative, Valid: parsedObs.Narrative != ""},
|
||||
}
|
||||
|
||||
// Check for duplicates within this import batch
|
||||
if similarity.IsSimilarToAny(tempObs, importedObs, dedupThreshold) {
|
||||
skippedDupes++
|
||||
continue
|
||||
}
|
||||
|
||||
// Store observation
|
||||
obsID, _, err := s.observationStore.StoreObservation(
|
||||
r.Context(),
|
||||
fmt.Sprintf("bulk-import-%d", sessionID),
|
||||
req.Project,
|
||||
parsedObs,
|
||||
0, // prompt number
|
||||
0, // discovery tokens
|
||||
)
|
||||
if err != nil {
|
||||
failed++
|
||||
errors = append(errors, fmt.Sprintf("observation %d: %v", i, err))
|
||||
continue
|
||||
}
|
||||
|
||||
// Sync to vector DB asynchronously with rate limiting
|
||||
if s.vectorSync != nil {
|
||||
s.asyncVectorSync(func() {
|
||||
// Use service context as parent to respect shutdown signals
|
||||
ctx, cancel := context.WithTimeout(s.ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
obs, err := s.observationStore.GetObservationByID(ctx, obsID)
|
||||
if err == nil && obs != nil {
|
||||
if syncErr := s.vectorSync.SyncObservation(ctx, obs); syncErr != nil {
|
||||
if s.ctx.Err() == nil { // Don't log during shutdown
|
||||
log.Debug().Err(syncErr).Int64("id", obsID).Msg("Failed to sync observation during bulk import")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Track for deduplication of subsequent observations in this batch
|
||||
importedObs = append(importedObs, tempObs)
|
||||
imported++
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("project", req.Project).
|
||||
Int("imported", imported).
|
||||
Int("failed", failed).
|
||||
Int("skipped_duplicates", skippedDupes).
|
||||
Msg("Bulk import completed")
|
||||
|
||||
// Invalidate observation count cache after import
|
||||
if imported > 0 {
|
||||
if req.Project != "" {
|
||||
s.invalidateObsCountCache(req.Project)
|
||||
} else {
|
||||
s.invalidateAllObsCountCache()
|
||||
}
|
||||
}
|
||||
|
||||
// Broadcast observation event for dashboard refresh
|
||||
s.sseBroadcaster.Broadcast(map[string]any{
|
||||
"type": "observation",
|
||||
"action": "bulk_import",
|
||||
"project": req.Project,
|
||||
"count": imported,
|
||||
})
|
||||
|
||||
writeJSON(w, BulkImportResponse{
|
||||
Imported: imported,
|
||||
Failed: failed,
|
||||
SkippedDuplicates: skippedDupes,
|
||||
Errors: errors,
|
||||
})
|
||||
}
|
||||
|
||||
// ArchiveRequest is the request body for archiving observations.
|
||||
type ArchiveRequest struct {
|
||||
Project string `json:"project,omitempty"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
IDs []int64 `json:"ids,omitempty"`
|
||||
MaxAgeDays int `json:"max_age_days,omitempty"`
|
||||
}
|
||||
|
||||
// handleArchiveObservations archives observations by ID or by age.
|
||||
// Supports batch archival with error tracking per observation.
|
||||
func (s *Service) handleArchiveObservations(w http.ResponseWriter, r *http.Request) {
|
||||
var req ArchiveRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var archivedIDs []int64
|
||||
var failedIDs []int64
|
||||
var errors []string
|
||||
var err error
|
||||
|
||||
if len(req.IDs) > 0 {
|
||||
// Archive specific observations with parallel processing for large batches
|
||||
if len(req.IDs) > 5 {
|
||||
// Use parallel archival for batches larger than 5
|
||||
type archiveResult struct {
|
||||
err error
|
||||
id int64
|
||||
}
|
||||
results := make(chan archiveResult, len(req.IDs))
|
||||
|
||||
// Limit concurrency to avoid overwhelming the database
|
||||
sem := make(chan struct{}, 5)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for _, id := range req.IDs {
|
||||
wg.Add(1)
|
||||
go func(obsID int64) {
|
||||
defer wg.Done()
|
||||
sem <- struct{}{} // Acquire
|
||||
defer func() { <-sem }() // Release
|
||||
|
||||
archErr := s.observationStore.ArchiveObservation(r.Context(), obsID, req.Reason)
|
||||
results <- archiveResult{id: obsID, err: archErr}
|
||||
}(id)
|
||||
}
|
||||
|
||||
// Close results channel when all goroutines complete
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(results)
|
||||
}()
|
||||
|
||||
// Collect results
|
||||
for res := range results {
|
||||
if res.err != nil {
|
||||
log.Warn().Err(res.err).Int64("id", res.id).Msg("Failed to archive observation")
|
||||
failedIDs = append(failedIDs, res.id)
|
||||
errors = append(errors, fmt.Sprintf("id %d: %v", res.id, res.err))
|
||||
} else {
|
||||
archivedIDs = append(archivedIDs, res.id)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Sequential for small batches
|
||||
for _, id := range req.IDs {
|
||||
if archErr := s.observationStore.ArchiveObservation(r.Context(), id, req.Reason); archErr != nil {
|
||||
log.Warn().Err(archErr).Int64("id", id).Msg("Failed to archive observation")
|
||||
failedIDs = append(failedIDs, id)
|
||||
errors = append(errors, fmt.Sprintf("id %d: %v", id, archErr))
|
||||
} else {
|
||||
archivedIDs = append(archivedIDs, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if req.Project != "" || req.MaxAgeDays > 0 {
|
||||
// Archive by age
|
||||
archivedIDs, err = s.observationStore.ArchiveOldObservations(r.Context(), req.Project, req.MaxAgeDays, req.Reason)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to archive: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
http.Error(w, "either 'ids' or 'project'/'max_age_days' is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("project", req.Project).
|
||||
Int("archived", len(archivedIDs)).
|
||||
Int("failed", len(failedIDs)).
|
||||
Msg("Observations archived")
|
||||
|
||||
// Invalidate cache if any observations were archived
|
||||
if len(archivedIDs) > 0 {
|
||||
if req.Project != "" {
|
||||
s.invalidateObsCountCache(req.Project)
|
||||
} else {
|
||||
s.invalidateAllObsCountCache()
|
||||
}
|
||||
}
|
||||
|
||||
response := map[string]any{
|
||||
"archived_count": len(archivedIDs),
|
||||
"archived_ids": archivedIDs,
|
||||
}
|
||||
if len(failedIDs) > 0 {
|
||||
response["failed_count"] = len(failedIDs)
|
||||
response["failed_ids"] = failedIDs
|
||||
response["errors"] = errors
|
||||
}
|
||||
|
||||
writeJSON(w, response)
|
||||
}
|
||||
|
||||
// handleUnarchiveObservation restores an archived observation.
|
||||
func (s *Service) handleUnarchiveObservation(w http.ResponseWriter, r *http.Request) {
|
||||
idStr := chi.URLParam(r, "id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid observation id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.observationStore.UnarchiveObservation(r.Context(), id); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Invalidate all caches since we don't know the project
|
||||
s.invalidateAllObsCountCache()
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"success": true,
|
||||
"id": id,
|
||||
})
|
||||
}
|
||||
|
||||
// handleGetArchivedObservations returns archived observations.
|
||||
func (s *Service) handleGetArchivedObservations(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
limit := gorm.ParseLimitParam(r, DefaultObservationsLimit)
|
||||
|
||||
observations, err := s.observationStore.GetArchivedObservations(r.Context(), project, limit)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if observations == nil {
|
||||
observations = []*models.Observation{}
|
||||
}
|
||||
|
||||
writeJSON(w, observations)
|
||||
}
|
||||
|
||||
// handleGetArchivalStats returns archival statistics.
|
||||
func (s *Service) handleGetArchivalStats(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
|
||||
stats, err := s.observationStore.GetArchivalStats(r.Context(), project)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, stats)
|
||||
}
|
||||
|
||||
// handleExportObservations exports observations in JSON or CSV format.
|
||||
// Supports query parameters: project, format (json/csv), scope, type, limit.
|
||||
func (s *Service) handleExportObservations(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
format := r.URL.Query().Get("format")
|
||||
if format == "" {
|
||||
format = "json"
|
||||
}
|
||||
scope := r.URL.Query().Get("scope") // project, global, or empty for all
|
||||
obsType := r.URL.Query().Get("type") // bugfix, feature, etc.
|
||||
limit := gorm.ParseLimitParamWithMax(r, 1000, 5000) // Higher limit for exports, capped at 5000
|
||||
|
||||
// Validate format
|
||||
if format != "json" && format != "csv" {
|
||||
http.Error(w, "format must be 'json' or 'csv'", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Get observations with filters
|
||||
ctx := r.Context()
|
||||
var observations []*models.Observation
|
||||
var err error
|
||||
|
||||
if project != "" {
|
||||
observations, _, err = s.observationStore.GetObservationsByProjectStrictPaginated(ctx, project, limit, 0)
|
||||
} else {
|
||||
observations, _, err = s.observationStore.GetAllRecentObservationsPaginated(ctx, limit, 0)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Apply additional filters
|
||||
if scope != "" || obsType != "" {
|
||||
filtered := make([]*models.Observation, 0, len(observations))
|
||||
for _, obs := range observations {
|
||||
if scope != "" && string(obs.Scope) != scope {
|
||||
continue
|
||||
}
|
||||
if obsType != "" && string(obs.Type) != obsType {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, obs)
|
||||
}
|
||||
observations = filtered
|
||||
}
|
||||
|
||||
// Generate filename
|
||||
timestamp := time.Now().Format("20060102-150405")
|
||||
filename := fmt.Sprintf("observations-%s.%s", timestamp, format)
|
||||
if project != "" {
|
||||
// Sanitize project name for filename
|
||||
sanitized := strings.ReplaceAll(project, "/", "_")
|
||||
sanitized = strings.ReplaceAll(sanitized, "\\", "_")
|
||||
if len(sanitized) > 50 {
|
||||
sanitized = sanitized[:50]
|
||||
}
|
||||
filename = fmt.Sprintf("observations-%s-%s.%s", sanitized, timestamp, format)
|
||||
}
|
||||
|
||||
switch format {
|
||||
case "csv":
|
||||
w.Header().Set("Content-Type", "text/csv")
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename))
|
||||
s.writeObservationsCSV(w, observations)
|
||||
default: // json
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename))
|
||||
writeJSON(w, map[string]any{
|
||||
"exported_at": time.Now().Format(time.RFC3339),
|
||||
"project": project,
|
||||
"count": len(observations),
|
||||
"observations": observations,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// writeObservationsCSV writes observations in CSV format.
|
||||
// Uses fmt.Fprintf directly to avoid intermediate string allocations.
|
||||
func (s *Service) writeObservationsCSV(w http.ResponseWriter, observations []*models.Observation) {
|
||||
// Write CSV header
|
||||
_, _ = io.WriteString(w, "id,type,scope,project,title,subtitle,narrative,concepts,facts,created_at,importance_score\n")
|
||||
|
||||
for _, obs := range observations {
|
||||
// Write directly to avoid string allocation per row
|
||||
_, _ = fmt.Fprintf(w, "%d,%s,%s,%s,%s,%s,%s,%s,%s,%s,%.2f\n",
|
||||
obs.ID,
|
||||
obs.Type,
|
||||
obs.Scope,
|
||||
escapeCsvField(obs.Project),
|
||||
escapeCsvField(obs.Title.String),
|
||||
escapeCsvField(obs.Subtitle.String),
|
||||
escapeCsvField(obs.Narrative.String),
|
||||
escapeCsvField(strings.Join(obs.Concepts, ";")),
|
||||
escapeCsvField(strings.Join(obs.Facts, ";")),
|
||||
obs.CreatedAt,
|
||||
obs.ImportanceScore,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// escapeCsvField escapes a field for CSV output.
|
||||
func escapeCsvField(s string) string {
|
||||
// If field contains comma, quote, or newline, wrap in quotes and escape quotes
|
||||
if strings.ContainsAny(s, ",\"\n\r") {
|
||||
s = strings.ReplaceAll(s, "\"", "\"\"")
|
||||
return "\"" + s + "\""
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// BulkStatusRequest represents a request to update status for multiple observations.
|
||||
type BulkStatusRequest struct {
|
||||
Action string `json:"action"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
IDs []int64 `json:"ids"`
|
||||
Feedback int `json:"feedback,omitempty"`
|
||||
}
|
||||
|
||||
// handleBulkStatusUpdate updates status for multiple observations in one request.
|
||||
func (s *Service) handleBulkStatusUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
var req BulkStatusRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.IDs) == 0 {
|
||||
http.Error(w, "ids is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.IDs) > 500 {
|
||||
http.Error(w, "maximum 500 ids per request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
var updated, failed int
|
||||
var errors []string
|
||||
|
||||
switch req.Action {
|
||||
case "supersede":
|
||||
for _, id := range req.IDs {
|
||||
if err := s.observationStore.MarkAsSuperseded(ctx, id); err != nil {
|
||||
failed++
|
||||
errors = append(errors, fmt.Sprintf("id %d: %v", id, err))
|
||||
} else {
|
||||
updated++
|
||||
}
|
||||
}
|
||||
|
||||
case "archive":
|
||||
for _, id := range req.IDs {
|
||||
if err := s.observationStore.ArchiveObservation(ctx, id, req.Reason); err != nil {
|
||||
failed++
|
||||
errors = append(errors, fmt.Sprintf("id %d: %v", id, err))
|
||||
} else {
|
||||
updated++
|
||||
}
|
||||
}
|
||||
|
||||
case "set_feedback":
|
||||
if req.Feedback < -1 || req.Feedback > 1 {
|
||||
http.Error(w, "feedback must be -1, 0, or 1", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
for _, id := range req.IDs {
|
||||
if err := s.observationStore.UpdateObservationFeedback(ctx, id, req.Feedback); err != nil {
|
||||
failed++
|
||||
errors = append(errors, fmt.Sprintf("id %d: %v", id, err))
|
||||
} else {
|
||||
updated++
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
http.Error(w, "action must be 'supersede', 'archive', or 'set_feedback'", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Invalidate cache for archive action (affects observation counts)
|
||||
if req.Action == "archive" && updated > 0 {
|
||||
// No project info available, invalidate all caches
|
||||
s.invalidateAllObsCountCache()
|
||||
}
|
||||
|
||||
response := map[string]any{
|
||||
"action": req.Action,
|
||||
"updated": updated,
|
||||
"failed": failed,
|
||||
}
|
||||
if len(errors) > 0 {
|
||||
response["errors"] = errors
|
||||
}
|
||||
|
||||
writeJSON(w, response)
|
||||
}
|
||||
|
||||
// handleFindDuplicates finds potential duplicate observations using similarity clustering.
|
||||
// Returns groups of similar observations that may be candidates for merging or archival.
|
||||
func (s *Service) handleFindDuplicates(w http.ResponseWriter, r *http.Request) {
|
||||
project := r.URL.Query().Get("project")
|
||||
thresholdStr := r.URL.Query().Get("threshold")
|
||||
limit := gorm.ParseLimitParam(r, 100)
|
||||
|
||||
// Parse threshold (default 0.6 = 60% similarity)
|
||||
threshold := 0.6
|
||||
if thresholdStr != "" {
|
||||
if t, err := strconv.ParseFloat(thresholdStr, 64); err == nil && t > 0 && t < 1 {
|
||||
threshold = t
|
||||
}
|
||||
}
|
||||
|
||||
// Get recent observations
|
||||
ctx := r.Context()
|
||||
var observations []*models.Observation
|
||||
var err error
|
||||
|
||||
if project != "" {
|
||||
observations, _, err = s.observationStore.GetObservationsByProjectStrictPaginated(ctx, project, limit, 0)
|
||||
} else {
|
||||
observations, _, err = s.observationStore.GetAllRecentObservationsPaginated(ctx, limit, 0)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if len(observations) < 2 {
|
||||
writeJSON(w, map[string]any{
|
||||
"duplicate_groups": []any{},
|
||||
"total_checked": len(observations),
|
||||
"threshold": threshold,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Find duplicates using similarity comparison
|
||||
type duplicateGroup struct {
|
||||
Observations []map[string]any `json:"observations"`
|
||||
Similarity float64 `json:"similarity"`
|
||||
}
|
||||
|
||||
groups := []duplicateGroup{}
|
||||
processed := make(map[int64]bool)
|
||||
|
||||
for i, obs1 := range observations {
|
||||
if processed[obs1.ID] {
|
||||
continue
|
||||
}
|
||||
|
||||
terms1 := similarity.ExtractObservationTerms(obs1)
|
||||
if len(terms1) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
group := duplicateGroup{
|
||||
Observations: []map[string]any{obs1.ToMap()},
|
||||
Similarity: 1.0,
|
||||
}
|
||||
|
||||
for j := i + 1; j < len(observations); j++ {
|
||||
obs2 := observations[j]
|
||||
if processed[obs2.ID] {
|
||||
continue
|
||||
}
|
||||
|
||||
terms2 := similarity.ExtractObservationTerms(obs2)
|
||||
sim := similarity.JaccardSimilarity(terms1, terms2)
|
||||
|
||||
if sim >= threshold {
|
||||
obsMap := obs2.ToMap()
|
||||
obsMap["similarity_to_first"] = sim
|
||||
group.Observations = append(group.Observations, obsMap)
|
||||
group.Similarity = min(group.Similarity, sim)
|
||||
processed[obs2.ID] = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(group.Observations) > 1 {
|
||||
processed[obs1.ID] = true
|
||||
groups = append(groups, group)
|
||||
}
|
||||
}
|
||||
|
||||
// Sort groups by size (largest first)
|
||||
sort.Slice(groups, func(i, j int) bool {
|
||||
return len(groups[i].Observations) > len(groups[j].Observations)
|
||||
})
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"duplicate_groups": groups,
|
||||
"total_checked": len(observations),
|
||||
"groups_found": len(groups),
|
||||
"threshold": threshold,
|
||||
"project": project,
|
||||
})
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// FeedbackRequest represents a user feedback submission.
|
||||
@@ -311,8 +312,7 @@ func (s *Service) handleTriggerRecalculation(w http.ResponseWriter, r *http.Requ
|
||||
// Run recalculation in background
|
||||
go func() {
|
||||
if err := recalculator.RecalculateNow(r.Context()); err != nil {
|
||||
// Log error but don't block response
|
||||
_ = err // Explicitly ignore - background operation
|
||||
log.Warn().Err(err).Msg("Background score recalculation failed")
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -345,14 +345,18 @@ func (s *Service) incrementRetrievalCounts(ids []int64) {
|
||||
}
|
||||
|
||||
// Increment in background to not block response
|
||||
// Use service context to respect shutdown signals
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
// Create a new context with timeout for the background operation
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer s.wg.Done()
|
||||
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := store.IncrementRetrievalCount(ctx, ids); err != nil {
|
||||
// Log but don't fail - this is a background operation
|
||||
_ = err // Explicitly ignore - background operation
|
||||
if s.ctx.Err() == nil { // Don't log during shutdown
|
||||
log.Debug().Err(err).Msg("Failed to increment retrieval counts")
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -0,0 +1,354 @@
|
||||
// Package worker provides session-related HTTP handlers.
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/privacy"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/internal/worker/session"
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// SessionInitRequest is the request body for session initialization.
|
||||
type SessionInitRequest struct {
|
||||
ClaudeSessionID string `json:"claudeSessionId"`
|
||||
Project string `json:"project"`
|
||||
Prompt string `json:"prompt"`
|
||||
MatchedObservations int `json:"matchedObservations"`
|
||||
}
|
||||
|
||||
// SessionInitResponse is the response for session initialization.
|
||||
type SessionInitResponse struct {
|
||||
Reason string `json:"reason,omitempty"`
|
||||
SessionDBID int64 `json:"sessionDbId"`
|
||||
PromptNumber int `json:"promptNumber"`
|
||||
Skipped bool `json:"skipped,omitempty"`
|
||||
}
|
||||
|
||||
// DuplicatePromptWindowSeconds is the time window for detecting duplicate prompt submissions.
|
||||
// If the same prompt text is seen within this window, it's considered a duplicate hook invocation.
|
||||
const DuplicatePromptWindowSeconds = 10
|
||||
|
||||
// handleSessionInit handles session initialization from user-prompt hook.
|
||||
// This handler is idempotent - duplicate requests within a short time window
|
||||
// return the existing prompt data without creating duplicates.
|
||||
func (s *Service) handleSessionInit(w http.ResponseWriter, r *http.Request) {
|
||||
var req SessionInitRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Privacy check
|
||||
if privacy.IsEntirelyPrivate(req.Prompt) {
|
||||
// Create session but skip processing
|
||||
sessionID, _ := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, "")
|
||||
promptNum, _ := s.sessionStore.IncrementPromptCounter(r.Context(), sessionID)
|
||||
|
||||
writeJSON(w, SessionInitResponse{
|
||||
SessionDBID: sessionID,
|
||||
PromptNumber: promptNum,
|
||||
Skipped: true,
|
||||
Reason: "private",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Clean prompt
|
||||
cleanedPrompt := privacy.Clean(req.Prompt)
|
||||
|
||||
// DUPLICATE DETECTION: Check if this exact prompt was already saved recently.
|
||||
// This prevents the bug where the hook fires multiple times for the same user action,
|
||||
// creating many duplicate prompts with incrementing numbers.
|
||||
if existingID, existingNum, found := s.promptStore.FindRecentPromptByText(r.Context(), req.ClaudeSessionID, cleanedPrompt, DuplicatePromptWindowSeconds); found {
|
||||
// Get or create session (idempotent)
|
||||
sessionID, _ := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, cleanedPrompt)
|
||||
|
||||
log.Debug().
|
||||
Int64("sessionId", sessionID).
|
||||
Int("promptNumber", existingNum).
|
||||
Int64("promptId", existingID).
|
||||
Msg("Duplicate prompt detected - returning existing")
|
||||
|
||||
// Return existing prompt data without incrementing or saving again
|
||||
writeJSON(w, SessionInitResponse{
|
||||
SessionDBID: sessionID,
|
||||
PromptNumber: existingNum,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Create session (idempotent)
|
||||
sessionID, err := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, cleanedPrompt)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Increment prompt counter
|
||||
promptNum, err := s.sessionStore.IncrementPromptCounter(r.Context(), sessionID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Save user prompt with matched observation count
|
||||
promptID, err := s.promptStore.SaveUserPromptWithMatches(r.Context(), req.ClaudeSessionID, promptNum, cleanedPrompt, req.MatchedObservations)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to save user prompt")
|
||||
// Non-fatal: continue with session initialization
|
||||
} else if s.vectorSync != nil {
|
||||
// Sync to vector DB asynchronously (non-blocking)
|
||||
now := time.Now()
|
||||
promptWithSession := &models.UserPromptWithSession{
|
||||
UserPrompt: models.UserPrompt{
|
||||
ID: promptID,
|
||||
ClaudeSessionID: req.ClaudeSessionID,
|
||||
PromptNumber: promptNum,
|
||||
PromptText: cleanedPrompt,
|
||||
MatchedObservations: req.MatchedObservations,
|
||||
CreatedAt: now.Format(time.RFC3339),
|
||||
CreatedAtEpoch: now.UnixMilli(),
|
||||
},
|
||||
Project: req.Project,
|
||||
SDKSessionID: req.ClaudeSessionID,
|
||||
}
|
||||
s.asyncVectorSync(func() {
|
||||
// Use service context as parent to respect shutdown signals
|
||||
ctx, cancel := context.WithTimeout(s.ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
if err := s.vectorSync.SyncUserPrompt(ctx, promptWithSession); err != nil {
|
||||
if s.ctx.Err() == nil { // Don't log during shutdown
|
||||
log.Warn().Err(err).Int64("id", promptID).Msg("Failed to sync user prompt to sqlite-vec")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Int64("sessionId", sessionID).
|
||||
Int("promptNumber", promptNum).
|
||||
Str("project", req.Project).
|
||||
Msg("Session initialized")
|
||||
|
||||
// Broadcast prompt event for dashboard refresh
|
||||
s.sseBroadcaster.Broadcast(map[string]any{
|
||||
"type": "prompt",
|
||||
"action": "created",
|
||||
"project": req.Project,
|
||||
})
|
||||
|
||||
writeJSON(w, SessionInitResponse{
|
||||
SessionDBID: sessionID,
|
||||
PromptNumber: promptNum,
|
||||
})
|
||||
}
|
||||
|
||||
// SessionStartRequest is the request body for starting SDK agent.
|
||||
type SessionStartRequest struct {
|
||||
UserPrompt string `json:"userPrompt"`
|
||||
PromptNumber int `json:"promptNumber"`
|
||||
}
|
||||
|
||||
// handleSessionStart handles SDK agent session start.
|
||||
func (s *Service) handleSessionStart(w http.ResponseWriter, r *http.Request) {
|
||||
idStr := chi.URLParam(r, "id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid session id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var req SessionStartRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Initialize session in manager
|
||||
sess, err := s.sessionManager.InitializeSession(r.Context(), id, req.UserPrompt, req.PromptNumber)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if sess == nil {
|
||||
http.Error(w, "session not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Session is now registered. Observations will be processed
|
||||
// asynchronously by the background queue processor (processQueue in service.go).
|
||||
log.Info().
|
||||
Int64("sessionId", id).
|
||||
Int("promptNumber", req.PromptNumber).
|
||||
Msg("SDK agent session initialized")
|
||||
|
||||
s.broadcastProcessingStatus()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// ObservationRequest is the request body for posting observations.
|
||||
type ObservationRequest struct {
|
||||
ClaudeSessionID string `json:"claudeSessionId"`
|
||||
Project string `json:"project"`
|
||||
ToolName string `json:"tool_name"`
|
||||
ToolInput any `json:"tool_input"`
|
||||
ToolResponse any `json:"tool_response"`
|
||||
CWD string `json:"cwd"`
|
||||
}
|
||||
|
||||
// handleObservation handles observation posting from post-tool-use hook.
|
||||
func (s *Service) handleObservation(w http.ResponseWriter, r *http.Request) {
|
||||
var req ObservationRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Find session
|
||||
sess, err := s.sessionStore.FindAnySDKSession(r.Context(), req.ClaudeSessionID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if sess == nil {
|
||||
// Create session on-the-fly with project from request
|
||||
id, err := s.sessionStore.CreateSDKSession(r.Context(), req.ClaudeSessionID, req.Project, "")
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
sess, _ = s.sessionStore.GetSessionByID(r.Context(), id)
|
||||
}
|
||||
|
||||
// Queue observation
|
||||
if err := s.sessionManager.QueueObservation(r.Context(), sess.ID, session.ObservationData{
|
||||
ToolName: req.ToolName,
|
||||
ToolInput: req.ToolInput,
|
||||
ToolResponse: req.ToolResponse,
|
||||
CWD: req.CWD,
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
s.broadcastProcessingStatus()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// SubagentCompleteRequest is the request body for subagent completion.
|
||||
type SubagentCompleteRequest struct {
|
||||
ClaudeSessionID string `json:"claudeSessionId"`
|
||||
Project string `json:"project"`
|
||||
}
|
||||
|
||||
// handleSubagentComplete handles subagent/Task completion notifications.
|
||||
// This triggers immediate processing of any queued observations from the subagent.
|
||||
func (s *Service) handleSubagentComplete(w http.ResponseWriter, r *http.Request) {
|
||||
var req SubagentCompleteRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Find session
|
||||
sess, err := s.sessionStore.FindAnySDKSession(r.Context(), req.ClaudeSessionID)
|
||||
if err != nil || sess == nil {
|
||||
// Session not found - subagent may have been in a different context
|
||||
log.Debug().
|
||||
Str("claudeSessionId", req.ClaudeSessionID).
|
||||
Msg("Subagent complete - no active session found")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
// Trigger immediate processing of queued observations
|
||||
messages := s.sessionManager.DrainMessages(sess.ID)
|
||||
if len(messages) > 0 && s.processor != nil {
|
||||
log.Info().
|
||||
Int64("sessionId", sess.ID).
|
||||
Int("messages", len(messages)).
|
||||
Msg("Processing queued observations from subagent")
|
||||
|
||||
for _, msg := range messages {
|
||||
if msg.Type == session.MessageTypeObservation && msg.Observation != nil {
|
||||
err := s.processor.ProcessObservation(
|
||||
r.Context(),
|
||||
sess.SDKSessionID.String,
|
||||
sess.Project,
|
||||
msg.Observation.ToolName,
|
||||
msg.Observation.ToolInput,
|
||||
msg.Observation.ToolResponse,
|
||||
msg.Observation.PromptNumber,
|
||||
msg.Observation.CWD,
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().Err(err).
|
||||
Str("tool", msg.Observation.ToolName).
|
||||
Msg("Failed to process subagent observation")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.broadcastProcessingStatus()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// handleGetSessionByClaudeID looks up a session by Claude session ID.
|
||||
func (s *Service) handleGetSessionByClaudeID(w http.ResponseWriter, r *http.Request) {
|
||||
claudeSessionID := r.URL.Query().Get("claudeSessionId")
|
||||
if claudeSessionID == "" {
|
||||
http.Error(w, "claudeSessionId required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
session, err := s.sessionStore.FindAnySDKSession(r.Context(), claudeSessionID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if session == nil {
|
||||
http.Error(w, "session not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, session)
|
||||
}
|
||||
|
||||
// SummarizeRequest is the request body for summarize requests.
|
||||
type SummarizeRequest struct {
|
||||
LastUserMessage string `json:"lastUserMessage"`
|
||||
LastAssistantMessage string `json:"lastAssistantMessage"`
|
||||
}
|
||||
|
||||
// handleSummarize handles summarize requests from stop hook.
|
||||
func (s *Service) handleSummarize(w http.ResponseWriter, r *http.Request) {
|
||||
idStr := chi.URLParam(r, "id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid session id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var req SummarizeRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Queue summarize request
|
||||
if err := s.sessionManager.QueueSummarize(r.Context(), id, req.LastUserMessage, req.LastAssistantMessage); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
s.broadcastProcessingStatus()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
@@ -66,6 +66,7 @@ func testService(t *testing.T) (*Service, func()) {
|
||||
cancel: cancel,
|
||||
startTime: time.Now(),
|
||||
retrievalStats: make(map[string]*RetrievalStats),
|
||||
cachedObsCounts: make(map[string]cachedCount),
|
||||
}
|
||||
|
||||
svc.setupRoutes()
|
||||
@@ -345,11 +346,13 @@ func TestHandleGetObservations_Limit(t *testing.T) {
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// Parse as generic JSON array since the model uses custom marshaling
|
||||
var observations []map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &observations)
|
||||
// Parse as object with observations key (API returns wrapped response)
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
observations, ok := response["observations"].([]interface{})
|
||||
require.True(t, ok, "expected observations array in response")
|
||||
assert.Len(t, observations, 10)
|
||||
}
|
||||
|
||||
@@ -1135,10 +1138,13 @@ func TestHandleGetObservations_DefaultLimit(t *testing.T) {
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var observations []map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &observations)
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
observations, ok := response["observations"].([]interface{})
|
||||
require.True(t, ok, "expected observations array in response")
|
||||
|
||||
// Should return default limit (100)
|
||||
assert.LessOrEqual(t, len(observations), DefaultObservationsLimit)
|
||||
}
|
||||
@@ -1159,10 +1165,12 @@ func TestHandleGetObservations_FilterByProject(t *testing.T) {
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var observations []map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &observations)
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
observations, ok := response["observations"].([]interface{})
|
||||
require.True(t, ok, "expected observations array in response")
|
||||
assert.Len(t, observations, 2)
|
||||
}
|
||||
|
||||
@@ -1412,10 +1420,12 @@ func TestHandleGetObservations(t *testing.T) {
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var observations []map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &observations)
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
observations, ok := response["observations"].([]interface{})
|
||||
require.True(t, ok, "expected observations array in response")
|
||||
assert.GreaterOrEqual(t, len(observations), 2)
|
||||
}
|
||||
|
||||
@@ -2697,10 +2707,13 @@ func TestHandleGetObservations_EmptyResult(t *testing.T) {
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// Should return empty array, not null
|
||||
var obs []interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &obs)
|
||||
// Should return empty array within observations key, not null
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
obs, ok := response["observations"].([]interface{})
|
||||
require.True(t, ok, "expected observations array in response")
|
||||
assert.NotNil(t, obs)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,243 @@
|
||||
// Package worker provides update and restart HTTP handlers.
|
||||
package worker
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// handleUpdateCheck checks for available updates.
|
||||
func (s *Service) handleUpdateCheck(w http.ResponseWriter, r *http.Request) {
|
||||
info, err := s.updater.CheckForUpdate(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
writeJSON(w, info)
|
||||
}
|
||||
|
||||
// handleUpdateApply downloads and applies an available update.
|
||||
func (s *Service) handleUpdateApply(w http.ResponseWriter, r *http.Request) {
|
||||
// First check for update
|
||||
info, err := s.updater.CheckForUpdate(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if !info.Available {
|
||||
writeJSON(w, map[string]any{
|
||||
"success": false,
|
||||
"message": "No update available",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Apply update in background with tracking for graceful shutdown
|
||||
s.wg.Go(func() {
|
||||
if err := s.updater.ApplyUpdate(s.ctx, info); err != nil {
|
||||
log.Error().Err(err).Msg("Update failed")
|
||||
}
|
||||
})
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"success": true,
|
||||
"message": "Update started",
|
||||
"version": info.LatestVersion,
|
||||
})
|
||||
}
|
||||
|
||||
// handleUpdateStatus returns the current update status.
|
||||
func (s *Service) handleUpdateStatus(w http.ResponseWriter, r *http.Request) {
|
||||
status := s.updater.GetStatus()
|
||||
writeJSON(w, status)
|
||||
}
|
||||
|
||||
// ComponentHealth represents the health status of a single component.
|
||||
type ComponentHealth struct {
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"` // "healthy", "degraded", "unhealthy"
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// SelfCheckResponse contains the health status of all components.
|
||||
type SelfCheckResponse struct {
|
||||
Overall string `json:"overall"` // "healthy", "degraded", "unhealthy"
|
||||
Version string `json:"version"`
|
||||
Uptime string `json:"uptime"`
|
||||
Components []ComponentHealth `json:"components"`
|
||||
}
|
||||
|
||||
// handleSelfCheck returns the health status of all components.
|
||||
func (s *Service) handleSelfCheck(w http.ResponseWriter, r *http.Request) {
|
||||
components := []ComponentHealth{}
|
||||
overall := "healthy"
|
||||
|
||||
// Check Worker Service
|
||||
workerStatus := ComponentHealth{Name: "Worker Service", Status: "healthy"}
|
||||
if !s.ready.Load() {
|
||||
if err := s.GetInitError(); err != nil {
|
||||
workerStatus.Status = "unhealthy"
|
||||
workerStatus.Message = err.Error()
|
||||
overall = "unhealthy"
|
||||
} else {
|
||||
workerStatus.Status = "degraded"
|
||||
workerStatus.Message = "Initializing"
|
||||
if overall == "healthy" {
|
||||
overall = "degraded"
|
||||
}
|
||||
}
|
||||
}
|
||||
components = append(components, workerStatus)
|
||||
|
||||
// Check SQLite Database
|
||||
dbStatus := ComponentHealth{Name: "SQLite Database", Status: "healthy"}
|
||||
if s.store == nil {
|
||||
dbStatus.Status = "unhealthy"
|
||||
dbStatus.Message = "Not initialized"
|
||||
overall = "unhealthy"
|
||||
} else if err := s.store.Ping(); err != nil {
|
||||
dbStatus.Status = "unhealthy"
|
||||
dbStatus.Message = err.Error()
|
||||
overall = "unhealthy"
|
||||
}
|
||||
components = append(components, dbStatus)
|
||||
|
||||
// Check Vector DB (sqlite-vec)
|
||||
vectorStatus := ComponentHealth{Name: "Vector DB", Status: "healthy"}
|
||||
if s.vectorClient == nil {
|
||||
vectorStatus.Status = "degraded"
|
||||
vectorStatus.Message = "Not configured"
|
||||
if overall == "healthy" {
|
||||
overall = "degraded"
|
||||
}
|
||||
} else if !s.vectorClient.IsConnected() {
|
||||
vectorStatus.Status = "degraded"
|
||||
vectorStatus.Message = "Not connected"
|
||||
if overall == "healthy" {
|
||||
overall = "degraded"
|
||||
}
|
||||
}
|
||||
components = append(components, vectorStatus)
|
||||
|
||||
// Check SDK Processor
|
||||
sdkStatus := ComponentHealth{Name: "SDK Processor", Status: "healthy"}
|
||||
if s.processor == nil {
|
||||
sdkStatus.Status = "degraded"
|
||||
sdkStatus.Message = "Not initialized"
|
||||
if overall == "healthy" {
|
||||
overall = "degraded"
|
||||
}
|
||||
} else if !s.processor.IsAvailable() {
|
||||
sdkStatus.Status = "degraded"
|
||||
sdkStatus.Message = "Claude CLI not available"
|
||||
if overall == "healthy" {
|
||||
overall = "degraded"
|
||||
}
|
||||
}
|
||||
components = append(components, sdkStatus)
|
||||
|
||||
// Check SSE Broadcaster
|
||||
sseStatus := ComponentHealth{Name: "SSE Broadcaster", Status: "healthy"}
|
||||
if s.sseBroadcaster == nil {
|
||||
sseStatus.Status = "unhealthy"
|
||||
sseStatus.Message = "Not initialized"
|
||||
overall = "unhealthy"
|
||||
}
|
||||
components = append(components, sseStatus)
|
||||
|
||||
// Check Cross-Encoder Reranker
|
||||
rerankerStatus := ComponentHealth{Name: "Cross-Encoder Reranker", Status: "healthy"}
|
||||
if !s.config.RerankingEnabled {
|
||||
rerankerStatus.Status = "degraded"
|
||||
rerankerStatus.Message = "Disabled in config"
|
||||
if overall == "healthy" {
|
||||
overall = "degraded"
|
||||
}
|
||||
} else if s.reranker == nil {
|
||||
rerankerStatus.Status = "degraded"
|
||||
rerankerStatus.Message = "Not initialized"
|
||||
if overall == "healthy" {
|
||||
overall = "degraded"
|
||||
}
|
||||
} else {
|
||||
// Verify reranker is functional using Score
|
||||
_, normalizedScore, err := s.reranker.Score("test query", "test document")
|
||||
if err != nil {
|
||||
rerankerStatus.Status = "unhealthy"
|
||||
rerankerStatus.Message = fmt.Sprintf("Score check failed: %v", err)
|
||||
if overall == "healthy" {
|
||||
overall = "degraded"
|
||||
}
|
||||
} else {
|
||||
rerankerStatus.Message = fmt.Sprintf("Score check passed (%.4f)", normalizedScore)
|
||||
}
|
||||
}
|
||||
components = append(components, rerankerStatus)
|
||||
|
||||
// Calculate uptime
|
||||
uptime := time.Since(s.startTime).Round(time.Second).String()
|
||||
|
||||
writeJSON(w, SelfCheckResponse{
|
||||
Overall: overall,
|
||||
Version: s.version,
|
||||
Uptime: uptime,
|
||||
Components: components,
|
||||
})
|
||||
}
|
||||
|
||||
// handleUpdateRestart restarts the worker with the new binary (after update).
|
||||
func (s *Service) handleUpdateRestart(w http.ResponseWriter, r *http.Request) {
|
||||
status := s.updater.GetStatus()
|
||||
if status.State != "done" {
|
||||
http.Error(w, "no update has been applied", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Send response before restarting
|
||||
writeJSON(w, map[string]any{
|
||||
"success": true,
|
||||
"message": "Restarting worker...",
|
||||
})
|
||||
|
||||
// Flush the response
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
// Restart in background after response is sent
|
||||
go func() {
|
||||
if err := s.updater.Restart(); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to restart worker")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// handleRestart restarts the worker process (general restart, not tied to update).
|
||||
func (s *Service) handleRestart(w http.ResponseWriter, r *http.Request) {
|
||||
log.Info().Msg("Manual restart requested via API")
|
||||
|
||||
// Send response before restarting
|
||||
writeJSON(w, map[string]any{
|
||||
"success": true,
|
||||
"message": "Restarting worker...",
|
||||
"version": s.version,
|
||||
})
|
||||
|
||||
// Flush the response
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
// Restart in background after response is sent
|
||||
go func() {
|
||||
// Small delay to ensure response is sent
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if err := s.updater.Restart(); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to restart worker")
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -0,0 +1,333 @@
|
||||
// Package worker provides the main worker service for claude-mnemonic.
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// requestIDKey is the context key for request IDs.
|
||||
type requestIDKey struct{}
|
||||
|
||||
// projectNamePattern validates project names to prevent path traversal.
|
||||
var projectNamePattern = regexp.MustCompile(`^[a-zA-Z0-9_./-]+$`)
|
||||
|
||||
// allowedOrigins is the whitelist of origins allowed for CORS.
|
||||
// Uses exact matching to prevent bypass attacks like "evil-localhost.com".
|
||||
var allowedOrigins = map[string]bool{
|
||||
"http://localhost": true,
|
||||
"http://localhost:3000": true,
|
||||
"http://localhost:5173": true, // Vite dev server
|
||||
"http://localhost:37778": true, // Dashboard UI
|
||||
"http://127.0.0.1": true,
|
||||
"http://127.0.0.1:3000": true,
|
||||
"http://127.0.0.1:5173": true,
|
||||
"http://127.0.0.1:37778": true,
|
||||
}
|
||||
|
||||
// SecurityHeaders middleware adds essential security headers to all responses.
|
||||
// These protect against common web vulnerabilities.
|
||||
func SecurityHeaders(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Prevent clickjacking
|
||||
w.Header().Set("X-Frame-Options", "DENY")
|
||||
|
||||
// Prevent MIME type sniffing
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
|
||||
// Enable XSS filter
|
||||
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
||||
|
||||
// Restrict referrer information
|
||||
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
|
||||
// Content Security Policy - restrict to self
|
||||
w.Header().Set("Content-Security-Policy", "default-src 'self'")
|
||||
|
||||
// Permissions Policy - disable unnecessary features
|
||||
w.Header().Set("Permissions-Policy", "geolocation=(), microphone=(), camera=()")
|
||||
|
||||
// CORS: Use exact match whitelist to prevent bypass attacks
|
||||
origin := r.Header.Get("Origin")
|
||||
if allowedOrigins[origin] {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, X-Auth-Token, Authorization, X-Request-ID")
|
||||
}
|
||||
|
||||
// Handle preflight requests
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// MaxBodySize middleware limits the size of incoming request bodies.
|
||||
// This prevents denial of service attacks via large payloads.
|
||||
func MaxBodySize(maxBytes int64) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.ContentLength > maxBytes {
|
||||
http.Error(w, "request body too large", http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxBytes)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TokenAuth provides simple token-based authentication for localhost services.
|
||||
// The token is generated at startup and must be provided in the X-Auth-Token header.
|
||||
type TokenAuth struct {
|
||||
ExemptPaths map[string]bool
|
||||
token string
|
||||
mu sync.RWMutex
|
||||
enabled bool
|
||||
}
|
||||
|
||||
// NewTokenAuth creates a new TokenAuth with a randomly generated token.
|
||||
// If enabled is false, authentication is skipped (useful for development).
|
||||
func NewTokenAuth(enabled bool) (*TokenAuth, error) {
|
||||
ta := &TokenAuth{
|
||||
enabled: enabled,
|
||||
ExemptPaths: map[string]bool{
|
||||
"/health": true,
|
||||
"/api/health": true,
|
||||
"/api/ready": true,
|
||||
},
|
||||
}
|
||||
|
||||
if enabled {
|
||||
// Generate 32-byte random token
|
||||
tokenBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(tokenBytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ta.token = hex.EncodeToString(tokenBytes)
|
||||
}
|
||||
|
||||
return ta, nil
|
||||
}
|
||||
|
||||
// Token returns the authentication token.
|
||||
// Returns empty string if authentication is disabled.
|
||||
func (ta *TokenAuth) Token() string {
|
||||
ta.mu.RLock()
|
||||
defer ta.mu.RUnlock()
|
||||
return ta.token
|
||||
}
|
||||
|
||||
// IsEnabled returns whether token authentication is enabled.
|
||||
func (ta *TokenAuth) IsEnabled() bool {
|
||||
ta.mu.RLock()
|
||||
defer ta.mu.RUnlock()
|
||||
return ta.enabled
|
||||
}
|
||||
|
||||
// Middleware returns HTTP middleware that enforces token authentication.
|
||||
func (ta *TokenAuth) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ta.mu.RLock()
|
||||
enabled := ta.enabled
|
||||
token := ta.token
|
||||
exempt := ta.ExemptPaths[r.URL.Path]
|
||||
ta.mu.RUnlock()
|
||||
|
||||
// Skip auth if disabled or path is exempt
|
||||
if !enabled || exempt {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Check for token in header
|
||||
providedToken := r.Header.Get("X-Auth-Token")
|
||||
if providedToken == "" {
|
||||
// Also check Authorization header with Bearer scheme
|
||||
auth := r.Header.Get("Authorization")
|
||||
if bearer, found := strings.CutPrefix(auth, "Bearer "); found {
|
||||
providedToken = bearer
|
||||
}
|
||||
}
|
||||
|
||||
if providedToken != token {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// ExpensiveOperationLimiter provides stricter rate limiting for expensive operations.
|
||||
// It wraps the base per-client rate limiter with additional per-operation limits.
|
||||
type ExpensiveOperationLimiter struct {
|
||||
// Track last execution time per operation type
|
||||
lastRebuild int64 // Unix timestamp
|
||||
rebuildCooldown int64 // Minimum seconds between rebuilds
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewExpensiveOperationLimiter creates a limiter for expensive operations.
|
||||
func NewExpensiveOperationLimiter() *ExpensiveOperationLimiter {
|
||||
return &ExpensiveOperationLimiter{
|
||||
rebuildCooldown: 300, // 5 minutes between rebuilds
|
||||
}
|
||||
}
|
||||
|
||||
// CanRebuild checks if a vector rebuild operation is allowed.
|
||||
// Returns false if a rebuild was triggered too recently.
|
||||
func (eol *ExpensiveOperationLimiter) CanRebuild() bool {
|
||||
eol.mu.Lock()
|
||||
defer eol.mu.Unlock()
|
||||
|
||||
now := unixNow()
|
||||
if now-eol.lastRebuild < eol.rebuildCooldown {
|
||||
return false
|
||||
}
|
||||
eol.lastRebuild = now
|
||||
return true
|
||||
}
|
||||
|
||||
// unixNow returns current Unix timestamp.
|
||||
// Separated for easier testing.
|
||||
func unixNow() int64 {
|
||||
return time.Now().Unix()
|
||||
}
|
||||
|
||||
// RequestID middleware adds a unique request ID to each request.
|
||||
// The ID is added to the context and response headers for tracing.
|
||||
func RequestID(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check for existing request ID from client
|
||||
requestID := r.Header.Get("X-Request-ID")
|
||||
if requestID == "" {
|
||||
// Generate new request ID
|
||||
idBytes := make([]byte, 8)
|
||||
if _, err := rand.Read(idBytes); err == nil {
|
||||
requestID = hex.EncodeToString(idBytes)
|
||||
} else {
|
||||
requestID = fmt.Sprintf("%d", time.Now().UnixNano())
|
||||
}
|
||||
}
|
||||
|
||||
// Add to response header
|
||||
w.Header().Set("X-Request-ID", requestID)
|
||||
|
||||
// Add to context
|
||||
ctx := context.WithValue(r.Context(), requestIDKey{}, requestID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// GetRequestID retrieves the request ID from the context.
|
||||
func GetRequestID(ctx context.Context) string {
|
||||
if id, ok := ctx.Value(requestIDKey{}).(string); ok {
|
||||
return id
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// RequireJSONContentType middleware validates that POST/PUT/PATCH requests
|
||||
// have application/json Content-Type header.
|
||||
func RequireJSONContentType(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Only check for methods that typically have bodies
|
||||
if r.Method == "POST" || r.Method == "PUT" || r.Method == "PATCH" {
|
||||
ct := r.Header.Get("Content-Type")
|
||||
// Allow empty Content-Type for requests without body
|
||||
if ct != "" && !strings.HasPrefix(ct, "application/json") {
|
||||
http.Error(w, "Content-Type must be application/json", http.StatusUnsupportedMediaType)
|
||||
return
|
||||
}
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// ValidateProjectName checks if a project name is safe to use.
|
||||
// Returns an error if the name contains path traversal or invalid characters.
|
||||
func ValidateProjectName(project string) error {
|
||||
if project == "" {
|
||||
return nil // Empty is allowed (means no filter)
|
||||
}
|
||||
|
||||
// Check for path traversal
|
||||
if strings.Contains(project, "..") {
|
||||
return fmt.Errorf("invalid project name: path traversal detected")
|
||||
}
|
||||
|
||||
// Check for valid characters
|
||||
if !projectNamePattern.MatchString(project) {
|
||||
return fmt.Errorf("invalid project name: only alphanumeric, underscore, dash, dot, and slash allowed")
|
||||
}
|
||||
|
||||
// Max length check
|
||||
if len(project) > 500 {
|
||||
return fmt.Errorf("project name too long (max 500 chars)")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BulkOperationLimiter provides rate limiting for bulk operations.
|
||||
// Prevents DoS via repeated bulk requests.
|
||||
type BulkOperationLimiter struct {
|
||||
lastBulkOp int64 // Unix timestamp
|
||||
cooldown int64 // Minimum seconds between operations
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewBulkOperationLimiter creates a limiter for bulk operations.
|
||||
func NewBulkOperationLimiter(cooldownSeconds int64) *BulkOperationLimiter {
|
||||
return &BulkOperationLimiter{
|
||||
cooldown: cooldownSeconds,
|
||||
}
|
||||
}
|
||||
|
||||
// CanExecute checks if a bulk operation is allowed.
|
||||
// Returns false if a bulk operation was triggered too recently.
|
||||
func (bol *BulkOperationLimiter) CanExecute() bool {
|
||||
bol.mu.Lock()
|
||||
defer bol.mu.Unlock()
|
||||
|
||||
now := unixNow()
|
||||
if now-bol.lastBulkOp < bol.cooldown {
|
||||
return false
|
||||
}
|
||||
bol.lastBulkOp = now
|
||||
return true
|
||||
}
|
||||
|
||||
// TimeSinceLastOp returns seconds since the last bulk operation.
|
||||
func (bol *BulkOperationLimiter) TimeSinceLastOp() int64 {
|
||||
bol.mu.Lock()
|
||||
defer bol.mu.Unlock()
|
||||
return unixNow() - bol.lastBulkOp
|
||||
}
|
||||
|
||||
// CooldownRemaining returns seconds remaining in the cooldown period.
|
||||
// Returns 0 if no cooldown is active.
|
||||
func (bol *BulkOperationLimiter) CooldownRemaining() int64 {
|
||||
bol.mu.Lock()
|
||||
defer bol.mu.Unlock()
|
||||
|
||||
remaining := bol.cooldown - (unixNow() - bol.lastBulkOp)
|
||||
if remaining < 0 {
|
||||
return 0
|
||||
}
|
||||
return remaining
|
||||
}
|
||||
@@ -0,0 +1,515 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSecurityHeaders(t *testing.T) {
|
||||
handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
// Check all security headers are set
|
||||
tests := []struct {
|
||||
header string
|
||||
expected string
|
||||
}{
|
||||
{"X-Frame-Options", "DENY"},
|
||||
{"X-Content-Type-Options", "nosniff"},
|
||||
{"X-XSS-Protection", "1; mode=block"},
|
||||
{"Referrer-Policy", "strict-origin-when-cross-origin"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := rr.Header().Get(tt.header); got != tt.expected {
|
||||
t.Errorf("SecurityHeaders() %s = %q, want %q", tt.header, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_CORS(t *testing.T) {
|
||||
handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
origin string
|
||||
expectedOrigin string
|
||||
expectCORS bool
|
||||
}{
|
||||
{
|
||||
name: "localhost:37778 origin allowed",
|
||||
origin: "http://localhost:37778",
|
||||
expectCORS: true,
|
||||
expectedOrigin: "http://localhost:37778",
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1:5173 origin allowed",
|
||||
origin: "http://127.0.0.1:5173",
|
||||
expectCORS: true,
|
||||
expectedOrigin: "http://127.0.0.1:5173",
|
||||
},
|
||||
{
|
||||
name: "localhost without port allowed",
|
||||
origin: "http://localhost",
|
||||
expectCORS: true,
|
||||
expectedOrigin: "http://localhost",
|
||||
},
|
||||
{
|
||||
name: "external origin blocked",
|
||||
origin: "http://evil.com",
|
||||
expectCORS: false,
|
||||
},
|
||||
{
|
||||
name: "evil-localhost.com bypass attempt blocked",
|
||||
origin: "http://evil-localhost.com",
|
||||
expectCORS: false,
|
||||
},
|
||||
{
|
||||
name: "localhost subdomain bypass attempt blocked",
|
||||
origin: "http://localhost.evil.com",
|
||||
expectCORS: false,
|
||||
},
|
||||
{
|
||||
name: "unknown localhost port blocked",
|
||||
origin: "http://localhost:9999",
|
||||
expectCORS: false,
|
||||
},
|
||||
{
|
||||
name: "no origin header",
|
||||
origin: "",
|
||||
expectCORS: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
if tt.origin != "" {
|
||||
req.Header.Set("Origin", tt.origin)
|
||||
}
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
cors := rr.Header().Get("Access-Control-Allow-Origin")
|
||||
if tt.expectCORS {
|
||||
if cors != tt.expectedOrigin {
|
||||
t.Errorf("Expected CORS origin %q, got %q", tt.expectedOrigin, cors)
|
||||
}
|
||||
} else {
|
||||
if cors != "" {
|
||||
t.Errorf("Expected no CORS header, got %q", cors)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxBodySize(t *testing.T) {
|
||||
maxSize := int64(100)
|
||||
handler := MaxBodySize(maxSize)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
contentLength int64
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "within limit",
|
||||
contentLength: 50,
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "at limit",
|
||||
contentLength: 100,
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "exceeds limit",
|
||||
contentLength: 150,
|
||||
expectedStatus: http.StatusRequestEntityTooLarge,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/test", nil)
|
||||
req.ContentLength = tt.contentLength
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != tt.expectedStatus {
|
||||
t.Errorf("MaxBodySize() status = %d, want %d", rr.Code, tt.expectedStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenAuth(t *testing.T) {
|
||||
t.Run("disabled auth allows all requests", func(t *testing.T) {
|
||||
ta, err := NewTokenAuth(false)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTokenAuth() error = %v", err)
|
||||
}
|
||||
|
||||
handler := ta.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("Expected OK with disabled auth, got %d", rr.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("enabled auth requires token", func(t *testing.T) {
|
||||
ta, err := NewTokenAuth(true)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTokenAuth() error = %v", err)
|
||||
}
|
||||
|
||||
handler := ta.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Without token
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected Unauthorized without token, got %d", rr.Code)
|
||||
}
|
||||
|
||||
// With correct token in X-Auth-Token header
|
||||
req = httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Auth-Token", ta.Token())
|
||||
rr = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("Expected OK with correct token, got %d", rr.Code)
|
||||
}
|
||||
|
||||
// With correct token in Authorization header
|
||||
req = httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+ta.Token())
|
||||
rr = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("Expected OK with Bearer token, got %d", rr.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("exempt paths skip auth", func(t *testing.T) {
|
||||
ta, err := NewTokenAuth(true)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTokenAuth() error = %v", err)
|
||||
}
|
||||
|
||||
handler := ta.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
exemptPaths := []string{"/health", "/api/health", "/api/ready"}
|
||||
for _, path := range exemptPaths {
|
||||
req := httptest.NewRequest("GET", path, nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("Expected OK for exempt path %s, got %d", path, rr.Code)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestExpensiveOperationLimiter(t *testing.T) {
|
||||
limiter := NewExpensiveOperationLimiter()
|
||||
|
||||
// First rebuild should be allowed
|
||||
if !limiter.CanRebuild() {
|
||||
t.Error("First rebuild should be allowed")
|
||||
}
|
||||
|
||||
// Immediate second rebuild should be blocked
|
||||
if limiter.CanRebuild() {
|
||||
t.Error("Immediate second rebuild should be blocked")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestID(t *testing.T) {
|
||||
handler := RequestID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Verify request ID is in context
|
||||
id := GetRequestID(r.Context())
|
||||
if id == "" {
|
||||
t.Error("Request ID should be set in context")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
t.Run("generates new request ID", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Header().Get("X-Request-ID") == "" {
|
||||
t.Error("X-Request-ID header should be set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses existing request ID", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Request-ID", "test-id-12345")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Header().Get("X-Request-ID") != "test-id-12345" {
|
||||
t.Errorf("Expected X-Request-ID to be test-id-12345, got %s", rr.Header().Get("X-Request-ID"))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRequireJSONContentType(t *testing.T) {
|
||||
handler := RequireJSONContentType(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
contentType string
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "GET request without content-type",
|
||||
method: "GET",
|
||||
contentType: "",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "POST with application/json",
|
||||
method: "POST",
|
||||
contentType: "application/json",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "POST with application/json; charset=utf-8",
|
||||
method: "POST",
|
||||
contentType: "application/json; charset=utf-8",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "POST without content-type (empty body)",
|
||||
method: "POST",
|
||||
contentType: "",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "POST with text/plain rejected",
|
||||
method: "POST",
|
||||
contentType: "text/plain",
|
||||
expectedStatus: http.StatusUnsupportedMediaType,
|
||||
},
|
||||
{
|
||||
name: "PUT with application/xml rejected",
|
||||
method: "PUT",
|
||||
contentType: "application/xml",
|
||||
expectedStatus: http.StatusUnsupportedMediaType,
|
||||
},
|
||||
{
|
||||
name: "PATCH with form-urlencoded rejected",
|
||||
method: "PATCH",
|
||||
contentType: "application/x-www-form-urlencoded",
|
||||
expectedStatus: http.StatusUnsupportedMediaType,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tt.method, "/test", nil)
|
||||
if tt.contentType != "" {
|
||||
req.Header.Set("Content-Type", tt.contentType)
|
||||
}
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, rr.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateProjectName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
project string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "empty project allowed",
|
||||
project: "",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "simple project name",
|
||||
project: "my-project",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "project with path",
|
||||
project: "org/my-project",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "project with underscore",
|
||||
project: "my_project_v2",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "project with dot",
|
||||
project: "my.project.name",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "path traversal attack",
|
||||
project: "../../../etc/passwd",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "hidden path traversal",
|
||||
project: "project/../../secret",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "shell injection attempt",
|
||||
project: "project; rm -rf /",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "backtick injection",
|
||||
project: "project`whoami`",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "special characters",
|
||||
project: "project$HOME",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "too long project name",
|
||||
project: string(make([]byte, 501)),
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateProjectName(tt.project)
|
||||
if tt.wantError && err == nil {
|
||||
t.Errorf("Expected error for project %q, got nil", tt.project)
|
||||
}
|
||||
if !tt.wantError && err != nil {
|
||||
t.Errorf("Unexpected error for project %q: %v", tt.project, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBulkOperationLimiter(t *testing.T) {
|
||||
limiter := NewBulkOperationLimiter(1) // 1 second cooldown for testing
|
||||
|
||||
// First operation should be allowed
|
||||
if !limiter.CanExecute() {
|
||||
t.Error("First bulk operation should be allowed")
|
||||
}
|
||||
|
||||
// Immediate second operation should be blocked
|
||||
if limiter.CanExecute() {
|
||||
t.Error("Immediate second bulk operation should be blocked")
|
||||
}
|
||||
|
||||
// Check cooldown remaining
|
||||
remaining := limiter.CooldownRemaining()
|
||||
if remaining <= 0 || remaining > 1 {
|
||||
t.Errorf("Expected cooldown remaining between 0-1 seconds, got %d", remaining)
|
||||
}
|
||||
|
||||
// Check time since last op
|
||||
since := limiter.TimeSinceLastOp()
|
||||
if since < 0 || since > 1 {
|
||||
t.Errorf("Expected time since last op between 0-1 seconds, got %d", since)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_CSP(t *testing.T) {
|
||||
handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
// Check CSP header is set
|
||||
csp := rr.Header().Get("Content-Security-Policy")
|
||||
if csp == "" {
|
||||
t.Error("Content-Security-Policy header should be set")
|
||||
}
|
||||
if csp != "default-src 'self'" {
|
||||
t.Errorf("Expected CSP to be \"default-src 'self'\", got %q", csp)
|
||||
}
|
||||
|
||||
// Check Permissions-Policy header
|
||||
pp := rr.Header().Get("Permissions-Policy")
|
||||
if pp == "" {
|
||||
t.Error("Permissions-Policy header should be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_Preflight(t *testing.T) {
|
||||
handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("OPTIONS", "/test", nil)
|
||||
req.Header.Set("Origin", "http://localhost:3000")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
// OPTIONS should return 204 No Content
|
||||
if rr.Code != http.StatusNoContent {
|
||||
t.Errorf("Expected status 204 for OPTIONS, got %d", rr.Code)
|
||||
}
|
||||
|
||||
// CORS headers should be set for allowed origin
|
||||
if rr.Header().Get("Access-Control-Allow-Origin") != "http://localhost:3000" {
|
||||
t.Errorf("CORS origin should be set for allowed origin")
|
||||
}
|
||||
if rr.Header().Get("Access-Control-Allow-Methods") == "" {
|
||||
t.Error("Access-Control-Allow-Methods should be set")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,226 @@
|
||||
// Package worker provides the main worker service for claude-mnemonic.
|
||||
package worker
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RateLimiter implements a token bucket rate limiter.
|
||||
type RateLimiter struct {
|
||||
lastUpdate time.Time
|
||||
rate float64
|
||||
burst int
|
||||
tokens float64
|
||||
requests int64
|
||||
rejected int64
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// LastUpdateTime returns the last update time.
|
||||
// Thread-safe - acquires the limiter's lock.
|
||||
func (rl *RateLimiter) LastUpdateTime() time.Time {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
return rl.lastUpdate
|
||||
}
|
||||
|
||||
// lastUpdateTimeUnlocked returns the last update time without locking.
|
||||
// Caller must hold rl.mu.
|
||||
func (rl *RateLimiter) lastUpdateTimeUnlocked() time.Time {
|
||||
return rl.lastUpdate
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new rate limiter.
|
||||
// rate is the number of requests per second to allow.
|
||||
// burst is the maximum burst of requests to allow.
|
||||
func NewRateLimiter(rate float64, burst int) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
rate: rate,
|
||||
burst: burst,
|
||||
tokens: float64(burst),
|
||||
lastUpdate: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Allow checks if a request should be allowed.
|
||||
// Returns true if the request is allowed, false if rate limited.
|
||||
func (rl *RateLimiter) Allow() bool {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
rl.requests++
|
||||
|
||||
// Calculate tokens added since last update
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(rl.lastUpdate).Seconds()
|
||||
rl.tokens += elapsed * rl.rate
|
||||
if rl.tokens > float64(rl.burst) {
|
||||
rl.tokens = float64(rl.burst)
|
||||
}
|
||||
rl.lastUpdate = now
|
||||
|
||||
// Check if we have a token available
|
||||
if rl.tokens >= 1 {
|
||||
rl.tokens--
|
||||
return true
|
||||
}
|
||||
|
||||
rl.rejected++
|
||||
return false
|
||||
}
|
||||
|
||||
// Stats returns rate limiter statistics.
|
||||
func (rl *RateLimiter) Stats() map[string]any {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
return map[string]any{
|
||||
"rate": rl.rate,
|
||||
"burst": rl.burst,
|
||||
"current_tokens": rl.tokens,
|
||||
"total_requests": rl.requests,
|
||||
"rejected": rl.rejected,
|
||||
"rejection_rate": float64(rl.rejected) / max(float64(rl.requests), 1),
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimitMiddleware creates middleware that applies rate limiting.
|
||||
// Uses a shared rate limiter for all requests.
|
||||
func RateLimitMiddleware(limiter *RateLimiter) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !limiter.Allow() {
|
||||
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// PerClientRateLimiter implements per-client rate limiting.
|
||||
type PerClientRateLimiter struct {
|
||||
lastCleanup time.Time
|
||||
clients map[string]*RateLimiter
|
||||
rate float64
|
||||
burst int
|
||||
cleanupInterval time.Duration
|
||||
maxIdleTime time.Duration
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewPerClientRateLimiter creates a new per-client rate limiter.
|
||||
func NewPerClientRateLimiter(rate float64, burst int) *PerClientRateLimiter {
|
||||
return &PerClientRateLimiter{
|
||||
rate: rate,
|
||||
burst: burst,
|
||||
clients: make(map[string]*RateLimiter),
|
||||
cleanupInterval: 5 * time.Minute,
|
||||
maxIdleTime: 10 * time.Minute,
|
||||
lastCleanup: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// getLimiter returns a rate limiter for the given client key.
|
||||
func (pcrl *PerClientRateLimiter) getLimiter(key string) *RateLimiter {
|
||||
pcrl.mu.Lock()
|
||||
defer pcrl.mu.Unlock()
|
||||
|
||||
// Periodic cleanup of idle clients
|
||||
if time.Since(pcrl.lastCleanup) > pcrl.cleanupInterval {
|
||||
pcrl.cleanupLocked()
|
||||
}
|
||||
|
||||
limiter, exists := pcrl.clients[key]
|
||||
if !exists {
|
||||
limiter = NewRateLimiter(pcrl.rate, pcrl.burst)
|
||||
pcrl.clients[key] = limiter
|
||||
}
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// cleanupLocked removes idle limiters. Must be called with lock held.
|
||||
// Uses consistent lock ordering: always acquire limiter.mu while holding pcrl.mu.
|
||||
// This is safe because the limiter.mu critical section is brief (just reading lastUpdate).
|
||||
func (pcrl *PerClientRateLimiter) cleanupLocked() {
|
||||
now := time.Now()
|
||||
keysToDelete := make([]string, 0)
|
||||
|
||||
// Check each limiter while holding pcrl.mu
|
||||
// We briefly acquire limiter.mu but the critical section is minimal
|
||||
for key, limiter := range pcrl.clients {
|
||||
limiter.mu.Lock()
|
||||
lastUpdate := limiter.lastUpdateTimeUnlocked()
|
||||
limiter.mu.Unlock()
|
||||
|
||||
if now.Sub(lastUpdate) > pcrl.maxIdleTime {
|
||||
keysToDelete = append(keysToDelete, key)
|
||||
}
|
||||
}
|
||||
|
||||
// Delete collected keys
|
||||
for _, key := range keysToDelete {
|
||||
delete(pcrl.clients, key)
|
||||
}
|
||||
pcrl.lastCleanup = now
|
||||
}
|
||||
|
||||
// Allow checks if a request from the given client should be allowed.
|
||||
func (pcrl *PerClientRateLimiter) Allow(clientKey string) bool {
|
||||
return pcrl.getLimiter(clientKey).Allow()
|
||||
}
|
||||
|
||||
// Stats returns aggregate statistics.
|
||||
// Uses two-phase approach to avoid nested lock acquisition.
|
||||
func (pcrl *PerClientRateLimiter) Stats() map[string]any {
|
||||
// Phase 1: Collect limiters under pcrl.mu
|
||||
pcrl.mu.Lock()
|
||||
rate := pcrl.rate
|
||||
burst := pcrl.burst
|
||||
activeClients := len(pcrl.clients)
|
||||
limiters := make([]*RateLimiter, 0, activeClients)
|
||||
for _, limiter := range pcrl.clients {
|
||||
limiters = append(limiters, limiter)
|
||||
}
|
||||
pcrl.mu.Unlock()
|
||||
|
||||
// Phase 2: Collect stats from each limiter (only acquiring limiter.mu, not pcrl.mu)
|
||||
var totalRequests, totalRejected int64
|
||||
for _, limiter := range limiters {
|
||||
limiter.mu.Lock()
|
||||
totalRequests += limiter.requests
|
||||
totalRejected += limiter.rejected
|
||||
limiter.mu.Unlock()
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"rate": rate,
|
||||
"burst": burst,
|
||||
"active_clients": activeClients,
|
||||
"total_requests": totalRequests,
|
||||
"total_rejected": totalRejected,
|
||||
}
|
||||
}
|
||||
|
||||
// PerClientRateLimitMiddleware creates middleware that applies per-client rate limiting.
|
||||
// Uses X-Forwarded-For or RemoteAddr to identify clients.
|
||||
func PerClientRateLimitMiddleware(limiter *PerClientRateLimiter) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get client identifier (prefer X-Real-IP from RealIP middleware)
|
||||
clientKey := r.RemoteAddr
|
||||
if xff := r.Header.Get("X-Real-IP"); xff != "" {
|
||||
clientKey = xff
|
||||
}
|
||||
|
||||
if !limiter.Allow(clientKey) {
|
||||
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,11 +4,15 @@ package sdk
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
@@ -20,8 +24,178 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// CircuitBreaker implements a simple circuit breaker pattern for CLI calls.
|
||||
type CircuitBreaker struct {
|
||||
failures int64 // Current failure count
|
||||
lastFailure int64 // Unix timestamp of last failure
|
||||
threshold int64 // Number of failures before opening
|
||||
resetTimeout int64 // Seconds to wait before trying again
|
||||
state int32 // 0=closed, 1=open, 2=half-open
|
||||
}
|
||||
|
||||
const (
|
||||
circuitClosed int32 = 0
|
||||
circuitOpen int32 = 1
|
||||
circuitHalfOpen int32 = 2
|
||||
)
|
||||
|
||||
// NewCircuitBreaker creates a new circuit breaker.
|
||||
func NewCircuitBreaker(threshold int64, resetTimeout int64) *CircuitBreaker {
|
||||
return &CircuitBreaker{
|
||||
threshold: threshold,
|
||||
resetTimeout: resetTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// Allow checks if a request should be allowed through.
|
||||
func (cb *CircuitBreaker) Allow() bool {
|
||||
state := atomic.LoadInt32(&cb.state)
|
||||
if state == circuitClosed {
|
||||
return true
|
||||
}
|
||||
|
||||
if state == circuitOpen {
|
||||
// Check if reset timeout has passed
|
||||
lastFail := atomic.LoadInt64(&cb.lastFailure)
|
||||
if time.Now().Unix()-lastFail > cb.resetTimeout {
|
||||
// Transition to half-open
|
||||
atomic.CompareAndSwapInt32(&cb.state, circuitOpen, circuitHalfOpen)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Half-open: allow one request through
|
||||
return true
|
||||
}
|
||||
|
||||
// RecordSuccess records a successful call.
|
||||
func (cb *CircuitBreaker) RecordSuccess() {
|
||||
atomic.StoreInt64(&cb.failures, 0)
|
||||
atomic.StoreInt32(&cb.state, circuitClosed)
|
||||
}
|
||||
|
||||
// RecordFailure records a failed call.
|
||||
func (cb *CircuitBreaker) RecordFailure() {
|
||||
failures := atomic.AddInt64(&cb.failures, 1)
|
||||
atomic.StoreInt64(&cb.lastFailure, time.Now().Unix())
|
||||
|
||||
if failures >= cb.threshold {
|
||||
atomic.StoreInt32(&cb.state, circuitOpen)
|
||||
log.Warn().Int64("failures", failures).Msg("Circuit breaker opened - Claude CLI calls temporarily disabled")
|
||||
}
|
||||
}
|
||||
|
||||
// State returns the current state as a string.
|
||||
func (cb *CircuitBreaker) State() string {
|
||||
switch atomic.LoadInt32(&cb.state) {
|
||||
case circuitOpen:
|
||||
return "open"
|
||||
case circuitHalfOpen:
|
||||
return "half-open"
|
||||
default:
|
||||
return "closed"
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreakerMetrics contains metrics about the circuit breaker state.
|
||||
type CircuitBreakerMetrics struct {
|
||||
State string `json:"state"`
|
||||
Failures int64 `json:"failures"`
|
||||
Threshold int64 `json:"threshold"`
|
||||
ResetTimeoutSecs int64 `json:"reset_timeout_secs"`
|
||||
LastFailureUnix int64 `json:"last_failure_unix,omitempty"`
|
||||
SecondsUntilReset int64 `json:"seconds_until_reset,omitempty"`
|
||||
}
|
||||
|
||||
// Metrics returns the current metrics of the circuit breaker.
|
||||
func (cb *CircuitBreaker) Metrics() CircuitBreakerMetrics {
|
||||
failures := atomic.LoadInt64(&cb.failures)
|
||||
lastFail := atomic.LoadInt64(&cb.lastFailure)
|
||||
state := cb.State()
|
||||
|
||||
metrics := CircuitBreakerMetrics{
|
||||
State: state,
|
||||
Failures: failures,
|
||||
Threshold: cb.threshold,
|
||||
ResetTimeoutSecs: cb.resetTimeout,
|
||||
}
|
||||
|
||||
if lastFail > 0 {
|
||||
metrics.LastFailureUnix = lastFail
|
||||
if state == "open" {
|
||||
remaining := cb.resetTimeout - (time.Now().Unix() - lastFail)
|
||||
if remaining > 0 {
|
||||
metrics.SecondsUntilReset = remaining
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
// RequestDeduplicator tracks recent requests to prevent duplicates.
|
||||
type RequestDeduplicator struct {
|
||||
seen map[string]int64 // hash -> timestamp
|
||||
mu sync.RWMutex
|
||||
ttlSecs int64
|
||||
maxSize int
|
||||
}
|
||||
|
||||
// NewRequestDeduplicator creates a new deduplicator.
|
||||
func NewRequestDeduplicator(ttlSecs int64, maxSize int) *RequestDeduplicator {
|
||||
return &RequestDeduplicator{
|
||||
seen: make(map[string]int64),
|
||||
ttlSecs: ttlSecs,
|
||||
maxSize: maxSize,
|
||||
}
|
||||
}
|
||||
|
||||
// IsDuplicate checks if a request hash was seen recently.
|
||||
func (d *RequestDeduplicator) IsDuplicate(hash string) bool {
|
||||
now := time.Now().Unix()
|
||||
|
||||
d.mu.RLock()
|
||||
ts, exists := d.seen[hash]
|
||||
d.mu.RUnlock()
|
||||
|
||||
if exists && now-ts < d.ttlSecs {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Record marks a request hash as seen.
|
||||
func (d *RequestDeduplicator) Record(hash string) {
|
||||
now := time.Now().Unix()
|
||||
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
// Evict old entries if at capacity
|
||||
if len(d.seen) >= d.maxSize {
|
||||
threshold := now - d.ttlSecs
|
||||
for k, ts := range d.seen {
|
||||
if ts < threshold {
|
||||
delete(d.seen, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
d.seen[hash] = now
|
||||
}
|
||||
|
||||
// hashRequest creates a hash of a request for deduplication.
|
||||
func hashRequest(toolName, input, output string) string {
|
||||
h := sha256.New()
|
||||
h.Write([]byte(toolName))
|
||||
h.Write([]byte(input))
|
||||
h.Write([]byte(output[:min(len(output), 1000)])) // Only hash first 1000 chars of output
|
||||
return hex.EncodeToString(h.Sum(nil))[:16] // Short hash is sufficient
|
||||
}
|
||||
|
||||
// BroadcastFunc is a callback for broadcasting events to SSE clients.
|
||||
type BroadcastFunc func(event map[string]interface{})
|
||||
type BroadcastFunc func(event map[string]any)
|
||||
|
||||
// SyncObservationFunc is a callback for syncing observations to vector DB.
|
||||
type SyncObservationFunc func(obs *models.Observation)
|
||||
@@ -29,16 +203,26 @@ type SyncObservationFunc func(obs *models.Observation)
|
||||
// SyncSummaryFunc is a callback for syncing summaries to vector DB.
|
||||
type SyncSummaryFunc func(summary *models.SessionSummary)
|
||||
|
||||
// MaxVectorSyncWorkers is the maximum number of concurrent vector sync operations.
|
||||
// This prevents unbounded goroutine spawning during high-volume observation ingestion.
|
||||
const MaxVectorSyncWorkers = 8
|
||||
|
||||
// Processor handles SDK agent processing of observations and summaries using Claude Code CLI.
|
||||
// Field order optimized for memory alignment (fieldalignment).
|
||||
type Processor struct {
|
||||
observationStore *gorm.ObservationStore
|
||||
summaryStore *gorm.SummaryStore
|
||||
broadcastFunc BroadcastFunc
|
||||
syncObservationFunc SyncObservationFunc
|
||||
syncSummaryFunc SyncSummaryFunc
|
||||
circuitBreaker *CircuitBreaker
|
||||
deduplicator *RequestDeduplicator
|
||||
vectorSyncChan chan *models.Observation
|
||||
vectorSyncDone chan struct{}
|
||||
sem chan struct{}
|
||||
claudePath string
|
||||
model string
|
||||
vectorSyncWg sync.WaitGroup
|
||||
}
|
||||
|
||||
// SetBroadcastFunc sets the broadcast callback for SSE events.
|
||||
@@ -57,7 +241,7 @@ func (p *Processor) SetSyncSummaryFunc(fn SyncSummaryFunc) {
|
||||
}
|
||||
|
||||
// broadcast sends an event via the broadcast callback if set.
|
||||
func (p *Processor) broadcast(event map[string]interface{}) {
|
||||
func (p *Processor) broadcast(event map[string]any) {
|
||||
if p.broadcastFunc != nil {
|
||||
p.broadcastFunc(event)
|
||||
}
|
||||
@@ -93,9 +277,65 @@ func NewProcessor(observationStore *gorm.ObservationStore, summaryStore *gorm.Su
|
||||
observationStore: observationStore,
|
||||
summaryStore: summaryStore,
|
||||
sem: make(chan struct{}, MaxConcurrentCLICalls),
|
||||
circuitBreaker: NewCircuitBreaker(5, 60), // Open after 5 failures, reset after 60s
|
||||
deduplicator: NewRequestDeduplicator(300, 1000), // 5-minute TTL, 1000 max entries
|
||||
vectorSyncChan: make(chan *models.Observation, MaxVectorSyncWorkers*2), // Buffered channel
|
||||
vectorSyncDone: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// StartVectorSyncWorkers starts the bounded worker pool for vector sync operations.
|
||||
// Call this after setting the sync function via SetSyncObservationFunc.
|
||||
func (p *Processor) StartVectorSyncWorkers() {
|
||||
for i := 0; i < MaxVectorSyncWorkers; i++ {
|
||||
p.vectorSyncWg.Add(1)
|
||||
go p.vectorSyncWorker()
|
||||
}
|
||||
log.Info().Int("workers", MaxVectorSyncWorkers).Msg("Vector sync worker pool started")
|
||||
}
|
||||
|
||||
// StopVectorSyncWorkers gracefully stops the worker pool.
|
||||
func (p *Processor) StopVectorSyncWorkers() {
|
||||
close(p.vectorSyncDone)
|
||||
p.vectorSyncWg.Wait()
|
||||
log.Info().Msg("Vector sync worker pool stopped")
|
||||
}
|
||||
|
||||
// vectorSyncWorker is a worker goroutine that processes vector sync requests.
|
||||
func (p *Processor) vectorSyncWorker() {
|
||||
defer p.vectorSyncWg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-p.vectorSyncDone:
|
||||
// Drain remaining items before exiting
|
||||
for {
|
||||
select {
|
||||
case obs := <-p.vectorSyncChan:
|
||||
if p.syncObservationFunc != nil {
|
||||
p.syncObservationFunc(obs)
|
||||
}
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
case obs := <-p.vectorSyncChan:
|
||||
if p.syncObservationFunc != nil {
|
||||
p.syncObservationFunc(obs)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreakerState returns the current state of the circuit breaker.
|
||||
func (p *Processor) CircuitBreakerState() string {
|
||||
return p.circuitBreaker.State()
|
||||
}
|
||||
|
||||
// CircuitBreakerMetrics returns detailed metrics about the circuit breaker.
|
||||
func (p *Processor) CircuitBreakerMetrics() CircuitBreakerMetrics {
|
||||
return p.circuitBreaker.Metrics()
|
||||
}
|
||||
|
||||
// IsAvailable checks if the Claude CLI is available for processing.
|
||||
func (p *Processor) IsAvailable() bool {
|
||||
_, err := os.Stat(p.claudePath)
|
||||
@@ -103,7 +343,7 @@ func (p *Processor) IsAvailable() bool {
|
||||
}
|
||||
|
||||
// ProcessObservation processes a single tool observation and extracts insights.
|
||||
func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, project string, toolName string, toolInput, toolResponse interface{}, promptNumber int, cwd string) error {
|
||||
func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, project string, toolName string, toolInput, toolResponse any, promptNumber int, cwd string) error {
|
||||
// Skip certain tools that aren't worth processing
|
||||
if shouldSkipTool(toolName) {
|
||||
log.Info().Str("tool", toolName).Msg("Skipping tool (not interesting for memory)")
|
||||
@@ -120,11 +360,23 @@ func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, projec
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check for duplicate request within TTL window
|
||||
reqHash := hashRequest(toolName, inputStr, outputStr)
|
||||
if p.deduplicator.IsDuplicate(reqHash) {
|
||||
log.Debug().Str("tool", toolName).Msg("Skipping duplicate request (dedup)")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check circuit breaker before making CLI call
|
||||
if !p.circuitBreaker.Allow() {
|
||||
log.Warn().Str("tool", toolName).Msg("Circuit breaker open - skipping CLI call")
|
||||
return fmt.Errorf("circuit breaker open")
|
||||
}
|
||||
|
||||
log.Info().Str("tool", toolName).Msg("Processing tool execution with Claude CLI")
|
||||
|
||||
// Note: Removed the "file already has observations" check
|
||||
// Each tool execution can produce unique insights even for the same file
|
||||
// Similarity-based deduplication will handle true duplicates
|
||||
// Record this request to prevent duplicates
|
||||
p.deduplicator.Record(reqHash)
|
||||
|
||||
// Build the prompt
|
||||
exec := ToolExecution{
|
||||
@@ -146,9 +398,11 @@ func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, projec
|
||||
// Call Claude Code CLI
|
||||
response, err := p.callClaudeCLI(ctx, prompt)
|
||||
if err != nil {
|
||||
p.circuitBreaker.RecordFailure()
|
||||
log.Error().Err(err).Str("tool", toolName).Msg("Failed to call Claude CLI for observation")
|
||||
return err
|
||||
}
|
||||
p.circuitBreaker.RecordSuccess()
|
||||
|
||||
// Parse observations from response
|
||||
observations := ParseObservations(response, sdkSessionID)
|
||||
@@ -199,16 +453,26 @@ func (p *Processor) ProcessObservation(ctx context.Context, sdkSessionID, projec
|
||||
Int("trackedFiles", len(obs.FileMtimes)).
|
||||
Msg("Observation stored")
|
||||
|
||||
// Sync to vector DB if callback is set
|
||||
if p.syncObservationFunc != nil {
|
||||
// Sync to vector DB via bounded worker pool (non-blocking to reduce latency)
|
||||
if p.syncObservationFunc != nil && p.vectorSyncChan != nil {
|
||||
fullObs := models.NewObservation(sdkSessionID, project, obs, promptNumber, 0)
|
||||
fullObs.ID = id
|
||||
fullObs.CreatedAtEpoch = createdAtEpoch
|
||||
p.syncObservationFunc(fullObs)
|
||||
// Non-blocking send to worker pool - drops if channel is full
|
||||
select {
|
||||
case p.vectorSyncChan <- fullObs:
|
||||
// Sent to worker pool
|
||||
default:
|
||||
// Channel full, fall back to direct sync in goroutine (bounded by channel buffer)
|
||||
log.Debug().Int64("obs_id", id).Msg("Vector sync channel full, using fallback goroutine")
|
||||
go func(obsToSync *models.Observation) {
|
||||
p.syncObservationFunc(obsToSync)
|
||||
}(fullObs)
|
||||
}
|
||||
}
|
||||
|
||||
// Broadcast new observation event for dashboard refresh
|
||||
p.broadcast(map[string]interface{}{
|
||||
p.broadcast(map[string]any{
|
||||
"type": "observation",
|
||||
"action": "created",
|
||||
"id": id,
|
||||
@@ -310,7 +574,7 @@ func (p *Processor) ProcessSummary(ctx context.Context, sessionDBID int64, sdkSe
|
||||
}
|
||||
|
||||
// Broadcast new summary event for dashboard refresh
|
||||
p.broadcast(map[string]interface{}{
|
||||
p.broadcast(map[string]any{
|
||||
"type": "summary",
|
||||
"action": "created",
|
||||
"id": id,
|
||||
@@ -320,8 +584,31 @@ func (p *Processor) ProcessSummary(ctx context.Context, sessionDBID int64, sdkSe
|
||||
return nil
|
||||
}
|
||||
|
||||
// MaxPromptSize is the maximum size of a prompt that can be passed to the Claude CLI.
|
||||
// This prevents resource exhaustion from extremely large prompts.
|
||||
const MaxPromptSize = 100 * 1024 // 100KB
|
||||
|
||||
// sanitizePrompt removes null bytes and control characters from a prompt.
|
||||
// Keeps newlines, tabs, and carriage returns as they're valid in prompts.
|
||||
func sanitizePrompt(s string) string {
|
||||
return strings.Map(func(r rune) rune {
|
||||
// Keep printable ASCII, extended Unicode, and common whitespace
|
||||
if r >= 32 || r == '\n' || r == '\t' || r == '\r' {
|
||||
return r
|
||||
}
|
||||
// Remove null bytes and other control characters
|
||||
return -1
|
||||
}, s)
|
||||
}
|
||||
|
||||
// callClaudeCLI calls the Claude Code CLI with the given prompt.
|
||||
func (p *Processor) callClaudeCLI(ctx context.Context, prompt string) (string, error) {
|
||||
// Validate and sanitize prompt
|
||||
if len(prompt) > MaxPromptSize {
|
||||
return "", fmt.Errorf("prompt exceeds maximum size of %d bytes", MaxPromptSize)
|
||||
}
|
||||
prompt = sanitizePrompt(prompt)
|
||||
|
||||
// Build the full prompt with system instructions
|
||||
fullPrompt := systemPrompt + "\n\n" + prompt
|
||||
|
||||
@@ -418,8 +705,11 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Skip if output indicates an error or empty result
|
||||
// Pre-compute lowercase strings once to avoid repeated allocations
|
||||
lowerOutput := strings.ToLower(outputStr)
|
||||
lowerInput := strings.ToLower(inputStr)
|
||||
|
||||
// Skip if output indicates an error or empty result
|
||||
trivialOutputs := []string{
|
||||
"no matches found",
|
||||
"file not found",
|
||||
@@ -443,13 +733,13 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
|
||||
// Skip reading config files that rarely contain project-specific insights
|
||||
boringFiles := []string{
|
||||
"package-lock.json", "yarn.lock", "pnpm-lock.yaml",
|
||||
"go.sum", "Cargo.lock", "Gemfile.lock", "poetry.lock",
|
||||
"go.sum", "cargo.lock", "gemfile.lock", "poetry.lock",
|
||||
".gitignore", ".dockerignore", ".eslintignore",
|
||||
"tsconfig.json", "jsconfig.json", "vite.config",
|
||||
"tailwind.config", "postcss.config",
|
||||
}
|
||||
for _, boring := range boringFiles {
|
||||
if strings.Contains(inputStr, boring) {
|
||||
if strings.Contains(lowerInput, boring) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -461,14 +751,14 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
|
||||
}
|
||||
|
||||
case "Bash":
|
||||
// Skip simple status commands
|
||||
// Skip simple status commands (use pre-computed lowerInput)
|
||||
boringCommands := []string{
|
||||
"git status", "git diff", "git log", "git branch",
|
||||
"ls ", "pwd", "echo ", "cat ", "which ", "type ",
|
||||
"npm list", "npm outdated", "npm audit",
|
||||
}
|
||||
for _, boring := range boringCommands {
|
||||
if strings.Contains(strings.ToLower(inputStr), boring) {
|
||||
if strings.Contains(lowerInput, boring) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -478,7 +768,7 @@ func shouldSkipTrivialOperation(toolName, inputStr, outputStr string) bool {
|
||||
}
|
||||
|
||||
// toJSONString converts an interface to a JSON string.
|
||||
func toJSONString(v interface{}) string {
|
||||
func toJSONString(v any) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
@@ -492,38 +782,132 @@ func toJSONString(v interface{}) string {
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// safeResolvePath resolves a path relative to cwd and validates it doesn't escape the cwd directory.
|
||||
// Returns the resolved absolute path and true if valid, or empty string and false if path traversal detected.
|
||||
// This function is a security sanitizer for path traversal attacks.
|
||||
func safeResolvePath(path, cwd string) (string, bool) {
|
||||
// Clean the input path to normalize any .. or . components
|
||||
cleanPath := filepath.Clean(path)
|
||||
|
||||
// Reject paths that explicitly contain parent directory traversal after cleaning
|
||||
if strings.Contains(cleanPath, "..") {
|
||||
return "", false
|
||||
}
|
||||
|
||||
if filepath.IsAbs(cleanPath) {
|
||||
// For absolute paths, verify they're within cwd if cwd is specified
|
||||
if cwd != "" {
|
||||
cleanCwd := filepath.Clean(cwd)
|
||||
if !strings.HasPrefix(cleanPath, cleanCwd+string(filepath.Separator)) && cleanPath != cleanCwd {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
return cleanPath, true
|
||||
}
|
||||
|
||||
if cwd == "" {
|
||||
return cleanPath, true
|
||||
}
|
||||
|
||||
// Clean the cwd first
|
||||
cleanCwd := filepath.Clean(cwd)
|
||||
|
||||
// Join and clean the path
|
||||
absPath := filepath.Join(cleanCwd, cleanPath)
|
||||
|
||||
// Use filepath.Rel to verify the path is actually within cwd
|
||||
// If Rel returns a path starting with "..", it escapes the base
|
||||
rel, err := filepath.Rel(cleanCwd, absPath)
|
||||
if err != nil || strings.HasPrefix(rel, "..") {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return absPath, true
|
||||
}
|
||||
|
||||
// captureFileMtimes captures current modification times for tracked files.
|
||||
// Returns a map of absolute file paths to their mtime in epoch milliseconds.
|
||||
// For large file lists (>10 files), uses parallel stat calls for better performance.
|
||||
func captureFileMtimes(filesRead, filesModified []string, cwd string) map[string]int64 {
|
||||
mtimes := make(map[string]int64)
|
||||
// Combine all unique file paths
|
||||
allPaths := make(map[string]struct{}, len(filesRead)+len(filesModified))
|
||||
for _, path := range filesRead {
|
||||
allPaths[path] = struct{}{}
|
||||
}
|
||||
for _, path := range filesModified {
|
||||
allPaths[path] = struct{}{}
|
||||
}
|
||||
|
||||
// Helper to get mtime for a file path
|
||||
getMtime := func(path string) (int64, bool) {
|
||||
// Resolve relative paths against cwd
|
||||
absPath := path
|
||||
if !filepath.IsAbs(path) && cwd != "" {
|
||||
absPath = filepath.Join(cwd, path)
|
||||
// For small lists, use sequential processing (goroutine overhead not worth it)
|
||||
if len(allPaths) <= 10 {
|
||||
return captureFileMtimesSequential(allPaths, cwd)
|
||||
}
|
||||
|
||||
// For larger lists, parallelize with bounded concurrency
|
||||
return captureFileMtimesParallel(allPaths, cwd)
|
||||
}
|
||||
|
||||
// captureFileMtimesSequential captures mtimes sequentially (efficient for small lists).
|
||||
func captureFileMtimesSequential(paths map[string]struct{}, cwd string) map[string]int64 {
|
||||
mtimes := make(map[string]int64, len(paths))
|
||||
|
||||
for path := range paths {
|
||||
absPath, ok := safeResolvePath(path, cwd)
|
||||
if !ok {
|
||||
// Skip paths that attempt directory traversal
|
||||
continue
|
||||
}
|
||||
|
||||
info, err := os.Stat(absPath)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
return info.ModTime().UnixMilli(), true
|
||||
}
|
||||
|
||||
// Capture mtimes for all read files
|
||||
for _, path := range filesRead {
|
||||
if mtime, ok := getMtime(path); ok {
|
||||
mtimes[path] = mtime
|
||||
if err == nil {
|
||||
mtimes[path] = info.ModTime().UnixMilli()
|
||||
}
|
||||
}
|
||||
|
||||
// Capture mtimes for all modified files
|
||||
for _, path := range filesModified {
|
||||
if mtime, ok := getMtime(path); ok {
|
||||
mtimes[path] = mtime
|
||||
}
|
||||
return mtimes
|
||||
}
|
||||
|
||||
// captureFileMtimesParallel captures mtimes in parallel with bounded concurrency.
|
||||
func captureFileMtimesParallel(paths map[string]struct{}, cwd string) map[string]int64 {
|
||||
type mtimeResult struct {
|
||||
path string
|
||||
mtime int64
|
||||
}
|
||||
|
||||
results := make(chan mtimeResult, len(paths))
|
||||
sem := make(chan struct{}, 8) // Limit to 8 concurrent stat calls
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for path := range paths {
|
||||
wg.Add(1)
|
||||
go func(p string) {
|
||||
defer wg.Done()
|
||||
sem <- struct{}{} // Acquire
|
||||
defer func() { <-sem }() // Release
|
||||
|
||||
absPath, ok := safeResolvePath(p, cwd)
|
||||
if !ok {
|
||||
// Skip paths that attempt directory traversal
|
||||
return
|
||||
}
|
||||
|
||||
info, err := os.Stat(absPath)
|
||||
if err == nil {
|
||||
results <- mtimeResult{path: p, mtime: info.ModTime().UnixMilli()}
|
||||
}
|
||||
}(path)
|
||||
}
|
||||
|
||||
// Close results channel when all goroutines complete
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(results)
|
||||
}()
|
||||
|
||||
// Collect results
|
||||
mtimes := make(map[string]int64, len(paths))
|
||||
for res := range results {
|
||||
mtimes[res.path] = res.mtime
|
||||
}
|
||||
|
||||
return mtimes
|
||||
@@ -538,12 +922,13 @@ func GetFileMtimes(paths []string, cwd string) map[string]int64 {
|
||||
// GetFileContent reads file content for verification purposes.
|
||||
// Returns content and ok status.
|
||||
func GetFileContent(path, cwd string) (string, bool) {
|
||||
absPath := path
|
||||
if !filepath.IsAbs(path) && cwd != "" {
|
||||
absPath = filepath.Join(cwd, path)
|
||||
absPath, ok := safeResolvePath(path, cwd)
|
||||
if !ok {
|
||||
// Reject paths that attempt directory traversal
|
||||
return "", false
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(absPath) // #nosec G304 -- intentional file read for verification
|
||||
content, err := os.ReadFile(absPath) // #nosec G304 -- path validated by safeResolvePath
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
@@ -974,3 +974,207 @@ func TestSyncSummaryFuncType(t *testing.T) {
|
||||
}
|
||||
assert.NotNil(t, fn)
|
||||
}
|
||||
|
||||
// TestSanitizePrompt tests prompt sanitization for CLI safety.
|
||||
func TestSanitizePrompt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal text",
|
||||
input: "Hello, world!",
|
||||
expected: "Hello, world!",
|
||||
},
|
||||
{
|
||||
name: "text with newlines",
|
||||
input: "Line 1\nLine 2\nLine 3",
|
||||
expected: "Line 1\nLine 2\nLine 3",
|
||||
},
|
||||
{
|
||||
name: "text with tabs",
|
||||
input: "Key:\tValue",
|
||||
expected: "Key:\tValue",
|
||||
},
|
||||
{
|
||||
name: "text with carriage return",
|
||||
input: "Line 1\r\nLine 2",
|
||||
expected: "Line 1\r\nLine 2",
|
||||
},
|
||||
{
|
||||
name: "text with null bytes",
|
||||
input: "Hello\x00World",
|
||||
expected: "HelloWorld",
|
||||
},
|
||||
{
|
||||
name: "text with control characters",
|
||||
input: "Hello\x01\x02\x03World",
|
||||
expected: "HelloWorld",
|
||||
},
|
||||
{
|
||||
name: "text with bell character",
|
||||
input: "Hello\x07World",
|
||||
expected: "HelloWorld",
|
||||
},
|
||||
{
|
||||
name: "text with backspace",
|
||||
input: "Hello\x08World",
|
||||
expected: "HelloWorld",
|
||||
},
|
||||
{
|
||||
name: "text with form feed",
|
||||
input: "Hello\x0cWorld",
|
||||
expected: "HelloWorld",
|
||||
},
|
||||
{
|
||||
name: "text with escape",
|
||||
input: "Hello\x1bWorld",
|
||||
expected: "HelloWorld",
|
||||
},
|
||||
{
|
||||
name: "unicode text",
|
||||
input: "Hello 世界 🌍",
|
||||
expected: "Hello 世界 🌍",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "only control characters",
|
||||
input: "\x00\x01\x02\x03",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := sanitizePrompt(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMaxPromptSize tests that MaxPromptSize is reasonable.
|
||||
func TestMaxPromptSize(t *testing.T) {
|
||||
assert.Equal(t, 100*1024, MaxPromptSize)
|
||||
}
|
||||
|
||||
// BenchmarkSanitizePrompt benchmarks the sanitize function.
|
||||
func BenchmarkSanitizePrompt(b *testing.B) {
|
||||
prompt := "Analyze the following code:\n```go\nfunc main() {\n\tfmt.Println(\"Hello, World!\")\n}\n```\n\nPlease identify any issues."
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
sanitizePrompt(prompt)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSanitizePromptWithControlChars benchmarks sanitization with control characters.
|
||||
func BenchmarkSanitizePromptWithControlChars(b *testing.B) {
|
||||
prompt := "Hello\x00World\x01Test\x02Data\x03End"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
sanitizePrompt(prompt)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSafeResolvePath tests the path traversal protection.
|
||||
func TestSafeResolvePath(t *testing.T) {
|
||||
// Create a temporary directory for testing
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
cwd string
|
||||
wantPath string
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
name: "simple relative path",
|
||||
path: "file.txt",
|
||||
cwd: tmpDir,
|
||||
wantOk: true,
|
||||
wantPath: filepath.Join(tmpDir, "file.txt"),
|
||||
},
|
||||
{
|
||||
name: "nested relative path",
|
||||
path: "subdir/file.txt",
|
||||
cwd: tmpDir,
|
||||
wantOk: true,
|
||||
wantPath: filepath.Join(tmpDir, "subdir", "file.txt"),
|
||||
},
|
||||
{
|
||||
name: "path traversal with ..",
|
||||
path: "../etc/passwd",
|
||||
cwd: tmpDir,
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "path traversal with multiple ..",
|
||||
path: "../../etc/passwd",
|
||||
cwd: tmpDir,
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "path traversal hidden in middle",
|
||||
path: "subdir/../../../etc/passwd",
|
||||
cwd: tmpDir,
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "just parent directory",
|
||||
path: "..",
|
||||
cwd: tmpDir,
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "absolute path without cwd",
|
||||
path: "/some/absolute/path",
|
||||
cwd: "",
|
||||
wantOk: true,
|
||||
wantPath: "/some/absolute/path",
|
||||
},
|
||||
{
|
||||
name: "relative path without cwd",
|
||||
path: "relative/path",
|
||||
cwd: "",
|
||||
wantOk: true,
|
||||
wantPath: "relative/path",
|
||||
},
|
||||
{
|
||||
name: "current directory reference",
|
||||
path: "./file.txt",
|
||||
cwd: tmpDir,
|
||||
wantOk: true,
|
||||
wantPath: filepath.Join(tmpDir, "file.txt"),
|
||||
},
|
||||
{
|
||||
name: "absolute path outside cwd",
|
||||
path: "/etc/passwd",
|
||||
cwd: tmpDir,
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "absolute path inside cwd",
|
||||
path: filepath.Join(tmpDir, "inside.txt"),
|
||||
cwd: tmpDir,
|
||||
wantOk: true,
|
||||
wantPath: filepath.Join(tmpDir, "inside.txt"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotPath, gotOk := safeResolvePath(tt.path, tt.cwd)
|
||||
assert.Equal(t, tt.wantOk, gotOk, "ok status mismatch")
|
||||
if tt.wantPath != "" && gotOk {
|
||||
assert.Equal(t, tt.wantPath, gotPath, "path mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+678
-443
File diff suppressed because it is too large
Load Diff
@@ -5,6 +5,8 @@ import (
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
//go:embed static/*
|
||||
@@ -13,16 +15,22 @@ var staticFS embed.FS
|
||||
// staticSubFS is the static subdirectory filesystem
|
||||
var staticSubFS fs.FS
|
||||
|
||||
// staticInitErr stores any error from static filesystem initialization
|
||||
var staticInitErr error
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
staticSubFS, err = fs.Sub(staticFS, "static")
|
||||
if err != nil {
|
||||
panic("failed to create sub filesystem: " + err.Error())
|
||||
staticSubFS, staticInitErr = fs.Sub(staticFS, "static")
|
||||
if staticInitErr != nil {
|
||||
log.Warn().Err(staticInitErr).Msg("Static filesystem initialization failed - dashboard will be unavailable")
|
||||
}
|
||||
}
|
||||
|
||||
// serveIndex serves the index.html file for the root path
|
||||
func serveIndex(w http.ResponseWriter, r *http.Request) {
|
||||
if staticInitErr != nil {
|
||||
http.Error(w, "Dashboard unavailable: static files not initialized", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
content, err := fs.ReadFile(staticSubFS, "index.html")
|
||||
if err != nil {
|
||||
http.Error(w, "Dashboard not found", http.StatusNotFound)
|
||||
@@ -38,6 +46,10 @@ func serveIndex(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// serveAssets serves static assets from the embedded filesystem
|
||||
func serveAssets(w http.ResponseWriter, r *http.Request) {
|
||||
if staticInitErr != nil {
|
||||
http.Error(w, "Assets unavailable: static files not initialized", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
// Strip the /assets/ prefix and serve the file
|
||||
path := strings.TrimPrefix(r.URL.Path, "/")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user