From 4f4b4ac70f3276627b73547e3689da30b34e3153 Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Wed, 7 Jan 2026 13:19:58 +0000 Subject: [PATCH] feat(chunking): add AST-aware code chunking for Go, Python, TypeScript MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - [x] Add language-specific chunkers with AST parsing (Go, Python, TypeScript) - [x] Implement chunking manager to dispatch files to appropriate chunkers - [x] Integrate code chunks into vector sync for semantic search - [x] Add tree-sitter dependency for Python/TypeScript parsing - [x] Reorder struct fields for consistency across codebase - [x] Rename error variables to follow Go conventions (err → unmarshalErr, etc.) - [x] Add code chunk metadata to vector documents (language, symbol name, line ranges) - [x] Update worker service to initialize chunking pipeline with all three languages --- cmd/hooks/session-start/main.go | 6 +- cmd/hooks/statusline/main.go | 20 +- cmd/hooks/stop/main.go | 6 +- cmd/hooks/user-prompt/main.go | 4 +- go.mod | 1 + go.sum | 2 + internal/chunking/golang/chunker.go | 285 +++++++++++++++ internal/chunking/golang/chunker_test.go | 214 +++++++++++ internal/chunking/manager.go | 106 ++++++ internal/chunking/manager_test.go | 162 +++++++++ internal/chunking/python/chunker.go | 291 +++++++++++++++ internal/chunking/types.go | 140 +++++++ internal/chunking/typescript/chunker.go | 403 +++++++++++++++++++++ internal/config/config.go | 49 +-- internal/config/config_test.go | 6 +- internal/db/sqlite/conflict.go | 12 +- internal/db/sqlite/helpers_test.go | 12 +- internal/db/sqlite/migrations.go | 8 +- internal/db/sqlite/observation.go | 8 +- internal/db/sqlite/observation_test.go | 4 +- internal/db/sqlite/prompt.go | 8 +- internal/db/sqlite/prompt_test.go | 10 +- internal/db/sqlite/scoring.go | 4 +- internal/db/sqlite/session_test.go | 2 +- internal/db/sqlite/store.go | 4 +- internal/db/sqlite/store_test.go | 14 +- internal/embedding/model.go | 23 +- internal/embedding/service.go | 16 +- internal/mcp/server.go | 20 +- internal/mcp/server_test.go | 28 +- internal/pattern/detector.go | 24 +- internal/pattern/detector_test.go | 14 +- internal/reranking/service.go | 36 +- internal/scoring/recalculator.go | 6 +- internal/scoring/recalculator_test.go | 8 +- internal/search/expansion/expander.go | 18 +- internal/search/expansion/expander_test.go | 24 +- internal/search/integration_test.go | 16 +- internal/search/manager.go | 30 +- internal/search/manager_test.go | 86 ++--- internal/update/update.go | 29 +- internal/vector/sqlitevec/client.go | 16 +- internal/vector/sqlitevec/helpers.go | 7 +- internal/vector/sqlitevec/helpers_test.go | 6 +- internal/vector/sqlitevec/sync.go | 120 +++++- internal/watcher/watcher.go | 10 +- internal/worker/handlers.go | 6 +- internal/worker/handlers_relations.go | 4 +- internal/worker/handlers_scoring.go | 13 +- internal/worker/handlers_test.go | 2 - internal/worker/sdk/parser_test.go | 4 +- internal/worker/sdk/processor.go | 7 +- internal/worker/sdk/processor_test.go | 20 +- internal/worker/sdk/prompts.go | 6 +- internal/worker/sdk/prompts_test.go | 6 +- internal/worker/service.go | 160 ++++---- internal/worker/session/manager.go | 43 +-- internal/worker/session/manager_test.go | 42 +-- internal/worker/sse/broadcaster.go | 2 +- internal/worker/sse/broadcaster_test.go | 2 +- pkg/hooks/response.go | 6 +- pkg/hooks/worker_test.go | 18 +- pkg/models/conflict.go | 12 +- pkg/models/conflict_test.go | 12 +- pkg/models/observation.go | 64 ++-- pkg/models/observation_test.go | 28 +- pkg/models/pattern.go | 50 +-- pkg/models/pattern_test.go | 16 +- pkg/models/prompt.go | 8 +- pkg/models/relation.go | 32 +- pkg/models/relation_test.go | 12 +- pkg/models/scoring.go | 30 +- pkg/models/session.go | 16 +- pkg/models/summary.go | 14 +- pkg/models/summary_test.go | 4 +- pkg/similarity/clustering_test.go | 2 +- ui/package-lock.json | 4 +- ui/package.json | 2 +- 78 files changed, 2313 insertions(+), 652 deletions(-) create mode 100644 internal/chunking/golang/chunker.go create mode 100644 internal/chunking/golang/chunker_test.go create mode 100644 internal/chunking/manager.go create mode 100644 internal/chunking/manager_test.go create mode 100644 internal/chunking/python/chunker.go create mode 100644 internal/chunking/types.go create mode 100644 internal/chunking/typescript/chunker.go diff --git a/cmd/hooks/session-start/main.go b/cmd/hooks/session-start/main.go index 7cb6d01..b12b0a3 100644 --- a/cmd/hooks/session-start/main.go +++ b/cmd/hooks/session-start/main.go @@ -23,12 +23,12 @@ type Input struct { // Observation represents an observation from the API. type Observation struct { - ID int64 `json:"id"` Type string `json:"type"` Title string `json:"title"` Subtitle string `json:"subtitle"` Narrative string `json:"narrative"` Facts []string `json:"facts"` + ID int64 `json:"id"` } func main() { @@ -46,8 +46,8 @@ func main() { } var input Input - if err := json.Unmarshal(inputData, &input); err != nil { - hooks.WriteError("SessionStart", err) + if unmarshalErr := json.Unmarshal(inputData, &input); unmarshalErr != nil { + hooks.WriteError("SessionStart", unmarshalErr) os.Exit(1) } diff --git a/cmd/hooks/statusline/main.go b/cmd/hooks/statusline/main.go index 6876d84..519afef 100644 --- a/cmd/hooks/statusline/main.go +++ b/cmd/hooks/statusline/main.go @@ -44,21 +44,21 @@ type StatusInput struct { // WorkerStats is the response from the worker's /api/stats endpoint. type WorkerStats struct { - Uptime string `json:"uptime"` - ActiveSessions int `json:"activeSessions"` - QueueDepth int `json:"queueDepth"` - IsProcessing bool `json:"isProcessing"` - ConnectedClients int `json:"connectedClients"` - SessionsToday int `json:"sessionsToday"` - Ready bool `json:"ready"` - Project string `json:"project,omitempty"` - ProjectObservations int `json:"projectObservations,omitempty"` - Retrieval struct { + Uptime string `json:"uptime"` + Project string `json:"project,omitempty"` + Retrieval struct { TotalRequests int64 `json:"TotalRequests"` ObservationsServed int64 `json:"ObservationsServed"` SearchRequests int64 `json:"SearchRequests"` ContextInjections int64 `json:"ContextInjections"` } `json:"retrieval"` + ActiveSessions int `json:"activeSessions"` + QueueDepth int `json:"queueDepth"` + ConnectedClients int `json:"connectedClients"` + SessionsToday int `json:"sessionsToday"` + ProjectObservations int `json:"projectObservations,omitempty"` + IsProcessing bool `json:"isProcessing"` + Ready bool `json:"ready"` } // ANSI color codes diff --git a/cmd/hooks/stop/main.go b/cmd/hooks/stop/main.go index cbb8861..88b9293 100644 --- a/cmd/hooks/stop/main.go +++ b/cmd/hooks/stop/main.go @@ -14,17 +14,17 @@ import ( // Input is the hook input from Claude Code. type Input struct { hooks.BaseInput - StopHookActive bool `json:"stop_hook_active"` TranscriptPath string `json:"transcript_path"` + StopHookActive bool `json:"stop_hook_active"` } // TranscriptMessage represents a message in the transcript JSONL file. type TranscriptMessage struct { - Type string `json:"type"` Message struct { + Content any `json:"content"` Role string `json:"role"` - Content any `json:"content"` // Can be string or array } `json:"message"` + Type string `json:"type"` // Can be string or array } // extractTextContent extracts text content from message content (handles both string and array formats). diff --git a/cmd/hooks/user-prompt/main.go b/cmd/hooks/user-prompt/main.go index 5b832f6..146d021 100644 --- a/cmd/hooks/user-prompt/main.go +++ b/cmd/hooks/user-prompt/main.go @@ -35,8 +35,8 @@ func main() { } var input Input - if err := json.Unmarshal(inputData, &input); err != nil { - hooks.WriteError("UserPromptSubmit", err) + if unmarshalErr := json.Unmarshal(inputData, &input); unmarshalErr != nil { + hooks.WriteError("UserPromptSubmit", unmarshalErr) os.Exit(1) } diff --git a/go.mod b/go.mod index 48aa104..b8f4cf1 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/goccy/go-json v0.10.5 github.com/mattn/go-sqlite3 v1.14.32 github.com/rs/zerolog v1.34.0 + github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82 github.com/stretchr/testify v1.11.1 github.com/sugarme/tokenizer v0.3.0 github.com/yalue/onnxruntime_go v1.25.0 diff --git a/go.sum b/go.sum index f657cf0..a1cd6c1 100644 --- a/go.sum +++ b/go.sum @@ -38,6 +38,8 @@ github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/schollz/progressbar/v2 v2.15.0 h1:dVzHQ8fHRmtPjD3K10jT3Qgn/+H+92jhPrhmxIJfDz8= github.com/schollz/progressbar/v2 v2.15.0/go.mod h1:UdPq3prGkfQ7MOzZKlDRpYKcFqEMczbD7YmbPgpzKMI= +github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82 h1:6C8qej6f1bStuePVkLSFxoU22XBS165D3klxlzRg8F4= +github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82/go.mod h1:xe4pgH49k4SsmkQq5OT8abwhWmnzkhpgnXeekbx2efw= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= diff --git a/internal/chunking/golang/chunker.go b/internal/chunking/golang/chunker.go new file mode 100644 index 0000000..c267cf5 --- /dev/null +++ b/internal/chunking/golang/chunker.go @@ -0,0 +1,285 @@ +// Package golang provides AST-aware chunking for Go source files. +package golang + +import ( + "context" + "fmt" + "go/ast" + "go/parser" + "go/token" + "os" + "strings" + + "github.com/lukaszraczylo/claude-mnemonic/internal/chunking" +) + +// Chunker implements AST-aware chunking for Go files. +type Chunker struct { + options chunking.ChunkOptions +} + +// NewChunker creates a new Go chunker. +func NewChunker(options chunking.ChunkOptions) *Chunker { + return &Chunker{options: options} +} + +// Language returns the language this chunker supports. +func (c *Chunker) Language() chunking.Language { + return chunking.LanguageGo +} + +// SupportedExtensions returns the file extensions this chunker handles. +func (c *Chunker) SupportedExtensions() []string { + return []string{".go"} +} + +// Chunk parses a Go source file and returns semantic code chunks. +func (c *Chunker) Chunk(ctx context.Context, filePath string) ([]chunking.Chunk, error) { + // Read file content + content, err := os.ReadFile(filePath) + if err != nil { + return nil, fmt.Errorf("read file: %w", err) + } + + // Parse the Go file + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, filePath, content, parser.ParseComments) + if err != nil { + return nil, fmt.Errorf("parse Go file: %w", err) + } + + chunks := make([]chunking.Chunk, 0) + sourceLines := strings.Split(string(content), "\n") + + // Extract chunks from declarations + for _, decl := range file.Decls { + switch d := decl.(type) { + case *ast.FuncDecl: + chunk := c.extractFunction(fset, d, sourceLines, filePath) + if chunk != nil { + chunks = append(chunks, *chunk) + } + case *ast.GenDecl: + extracted := c.extractGenDecl(fset, d, sourceLines, filePath) + chunks = append(chunks, extracted...) + } + } + + return chunks, nil +} + +// extractFunction extracts a function or method declaration as a chunk. +func (c *Chunker) extractFunction(fset *token.FileSet, fn *ast.FuncDecl, sourceLines []string, filePath string) *chunking.Chunk { + // Skip unexported if configured + if !c.options.IncludePrivate && !fn.Name.IsExported() { + return nil + } + + startPos := fset.Position(fn.Pos()) + endPos := fset.Position(fn.End()) + + chunk := &chunking.Chunk{ + FilePath: filePath, + Language: chunking.LanguageGo, + Name: fn.Name.Name, + StartLine: startPos.Line, + EndLine: endPos.Line, + } + + // Determine if this is a method or a function + if fn.Recv != nil && len(fn.Recv.List) > 0 { + chunk.Type = chunking.ChunkTypeMethod + chunk.ParentName = c.extractReceiverType(fn.Recv) + } else { + chunk.Type = chunking.ChunkTypeFunction + } + + // Extract content + chunk.Content = c.extractLines(sourceLines, startPos.Line, endPos.Line) + + // Extract signature (function declaration without body) + chunk.Signature = c.extractFunctionSignature(fn, fset, sourceLines) + + // Extract doc comment + if c.options.IncludeDocComments && fn.Doc != nil { + chunk.DocComment = strings.TrimSpace(fn.Doc.Text()) + } + + return chunk +} + +// extractGenDecl extracts general declarations (type, const, var). +func (c *Chunker) extractGenDecl(fset *token.FileSet, gd *ast.GenDecl, sourceLines []string, filePath string) []chunking.Chunk { + var chunks []chunking.Chunk + + for _, spec := range gd.Specs { + switch s := spec.(type) { + case *ast.TypeSpec: + chunk := c.extractTypeSpec(fset, gd, s, sourceLines, filePath) + if chunk != nil { + chunks = append(chunks, *chunk) + } + case *ast.ValueSpec: + // Handle const and var declarations + chunk := c.extractValueSpec(fset, gd, s, sourceLines, filePath) + if chunk != nil { + chunks = append(chunks, *chunk) + } + } + } + + return chunks +} + +// extractTypeSpec extracts a type declaration (struct, interface, type alias). +func (c *Chunker) extractTypeSpec(fset *token.FileSet, gd *ast.GenDecl, ts *ast.TypeSpec, sourceLines []string, filePath string) *chunking.Chunk { + // Skip unexported if configured + if !c.options.IncludePrivate && !ts.Name.IsExported() { + return nil + } + + startPos := fset.Position(gd.Pos()) + endPos := fset.Position(gd.End()) + + chunk := &chunking.Chunk{ + FilePath: filePath, + Language: chunking.LanguageGo, + Name: ts.Name.Name, + StartLine: startPos.Line, + EndLine: endPos.Line, + Content: c.extractLines(sourceLines, startPos.Line, endPos.Line), + } + + // Determine chunk type based on type expression + switch ts.Type.(type) { + case *ast.StructType: + chunk.Type = chunking.ChunkTypeClass // Treat struct as class + case *ast.InterfaceType: + chunk.Type = chunking.ChunkTypeInterface + default: + chunk.Type = chunking.ChunkTypeType + } + + // Extract doc comment + if c.options.IncludeDocComments && gd.Doc != nil { + chunk.DocComment = strings.TrimSpace(gd.Doc.Text()) + } + + return chunk +} + +// extractValueSpec extracts const or var declarations. +func (c *Chunker) extractValueSpec(fset *token.FileSet, gd *ast.GenDecl, vs *ast.ValueSpec, sourceLines []string, filePath string) *chunking.Chunk { + // Skip if all names are unexported and we're excluding private + if !c.options.IncludePrivate { + allUnexported := true + for _, name := range vs.Names { + if name.IsExported() { + allUnexported = false + break + } + } + if allUnexported { + return nil + } + } + + startPos := fset.Position(gd.Pos()) + endPos := fset.Position(gd.End()) + + // Use first name as the chunk name, join multiple if present + names := make([]string, len(vs.Names)) + for i, name := range vs.Names { + names[i] = name.Name + } + + chunk := &chunking.Chunk{ + FilePath: filePath, + Language: chunking.LanguageGo, + Name: strings.Join(names, ", "), + StartLine: startPos.Line, + EndLine: endPos.Line, + Content: c.extractLines(sourceLines, startPos.Line, endPos.Line), + } + + // Set type based on token + if gd.Tok == token.CONST { + chunk.Type = chunking.ChunkTypeConst + } else { + chunk.Type = chunking.ChunkTypeVar + } + + // Extract doc comment + if c.options.IncludeDocComments && gd.Doc != nil { + chunk.DocComment = strings.TrimSpace(gd.Doc.Text()) + } + + return chunk +} + +// extractReceiverType extracts the receiver type name from a method. +func (c *Chunker) extractReceiverType(recv *ast.FieldList) string { + if len(recv.List) == 0 { + return "" + } + + field := recv.List[0] + switch t := field.Type.(type) { + case *ast.Ident: + return t.Name + case *ast.StarExpr: + if ident, ok := t.X.(*ast.Ident); ok { + return ident.Name + } + } + + return "" +} + +// extractFunctionSignature extracts the function signature without the body. +func (c *Chunker) extractFunctionSignature(fn *ast.FuncDecl, fset *token.FileSet, sourceLines []string) string { + if fn.Body == nil { + // No body, return entire declaration + startPos := fset.Position(fn.Pos()) + endPos := fset.Position(fn.End()) + return c.extractLines(sourceLines, startPos.Line, endPos.Line) + } + + // Extract from start of function to just before body + startPos := fset.Position(fn.Pos()) + bodyPos := fset.Position(fn.Body.Pos()) + + // If body is on the same line, extract just that line up to the opening brace + if startPos.Line == bodyPos.Line { + line := sourceLines[startPos.Line-1] + // Find the opening brace position + if idx := strings.Index(line[startPos.Column-1:], "{"); idx >= 0 { + return strings.TrimSpace(line[startPos.Column-1 : startPos.Column-1+idx]) + } + return strings.TrimSpace(line[startPos.Column-1:]) + } + + // Get lines from start to the line containing the opening brace + sig := c.extractLines(sourceLines, startPos.Line, bodyPos.Line) + // Remove the opening brace and anything after it + if idx := strings.Index(sig, "{"); idx >= 0 { + sig = sig[:idx] + } + return strings.TrimSpace(sig) +} + +// extractLines extracts a range of lines from source (1-indexed, inclusive). +func (c *Chunker) extractLines(lines []string, start, end int) string { + if start < 1 || end < start || start > len(lines) { + return "" + } + + // Adjust for 0-indexed array (start and end are 1-indexed) + startIdx := start - 1 + endIdx := end + if endIdx > len(lines) { + endIdx = len(lines) + } + + return strings.Join(lines[startIdx:endIdx], "\n") +} diff --git a/internal/chunking/golang/chunker_test.go b/internal/chunking/golang/chunker_test.go new file mode 100644 index 0000000..f09adc9 --- /dev/null +++ b/internal/chunking/golang/chunker_test.go @@ -0,0 +1,214 @@ +package golang + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/lukaszraczylo/claude-mnemonic/internal/chunking" +) + +func TestGoChunker_BasicFunctions(t *testing.T) { + // Create temp test file + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + + testCode := `package main + +import "fmt" + +// Greet prints a greeting message +func Greet(name string) { + fmt.Printf("Hello, %s!\n", name) +} + +// Add adds two numbers +func Add(a, b int) int { + return a + b +} + +// unexported function should be included by default +func helper() string { + return "helper" +} +` + + if err := os.WriteFile(testFile, []byte(testCode), 0600); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + + // Create chunker with default options + chunker := NewChunker(chunking.DefaultChunkOptions()) + + // Chunk the file + chunks, err := chunker.Chunk(context.Background(), testFile) + if err != nil { + t.Fatalf("Chunk() failed: %v", err) + } + + // Verify we got all functions + if len(chunks) != 3 { + t.Errorf("Expected 3 chunks (Greet, Add, helper), got %d", len(chunks)) + } + + // Verify chunk details + expectedNames := map[string]bool{ + "Greet": false, + "Add": false, + "helper": false, + } + + for _, chunk := range chunks { + if chunk.Type != chunking.ChunkTypeFunction { + t.Errorf("Expected chunk type 'function', got '%s'", chunk.Type) + } + + if chunk.Language != chunking.LanguageGo { + t.Errorf("Expected language 'go', got '%s'", chunk.Language) + } + + if _, ok := expectedNames[chunk.Name]; !ok { + t.Errorf("Unexpected function name: %s", chunk.Name) + } else { + expectedNames[chunk.Name] = true + } + + // Verify content is non-empty + if chunk.Content == "" { + t.Errorf("Chunk %s has empty content", chunk.Name) + } + + // Verify signature is present for functions + if chunk.Signature == "" { + t.Errorf("Chunk %s has empty signature", chunk.Name) + } + } + + // Verify all expected functions were found + for name, found := range expectedNames { + if !found { + t.Errorf("Expected function %s not found", name) + } + } +} + +func TestGoChunker_StructsAndMethods(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + + testCode := `package main + +// User represents a user +type User struct { + ID int + Name string +} + +// GetName returns the user's name +func (u *User) GetName() string { + return u.Name +} + +// SetName sets the user's name +func (u *User) SetName(name string) { + u.Name = name +} +` + + if err := os.WriteFile(testFile, []byte(testCode), 0600); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + + chunker := NewChunker(chunking.DefaultChunkOptions()) + chunks, err := chunker.Chunk(context.Background(), testFile) + if err != nil { + t.Fatalf("Chunk() failed: %v", err) + } + + // Should have 1 struct + 2 methods = 3 chunks + if len(chunks) != 3 { + t.Errorf("Expected 3 chunks (User struct, GetName, SetName), got %d", len(chunks)) + } + + // Find the struct and methods + var structChunk, getNameChunk, setNameChunk *chunking.Chunk + for i := range chunks { + switch chunks[i].Name { + case "User": + structChunk = &chunks[i] + case "GetName": + getNameChunk = &chunks[i] + case "SetName": + setNameChunk = &chunks[i] + } + } + + // Verify struct + if structChunk == nil { + t.Fatal("User struct not found") + } + if structChunk.Type != chunking.ChunkTypeClass { + t.Errorf("Expected User to be ChunkTypeClass, got %s", structChunk.Type) + } + + // Verify methods + if getNameChunk == nil { + t.Fatal("GetName method not found") + } + if getNameChunk.Type != chunking.ChunkTypeMethod { + t.Errorf("Expected GetName to be ChunkTypeMethod, got %s", getNameChunk.Type) + } + if getNameChunk.ParentName != "User" { + t.Errorf("Expected GetName parent to be 'User', got '%s'", getNameChunk.ParentName) + } + + if setNameChunk == nil { + t.Fatal("SetName method not found") + } + if setNameChunk.Type != chunking.ChunkTypeMethod { + t.Errorf("Expected SetName to be ChunkTypeMethod, got %s", setNameChunk.Type) + } + if setNameChunk.ParentName != "User" { + t.Errorf("Expected SetName parent to be 'User', got '%s'", setNameChunk.ParentName) + } +} + +func TestGoChunker_DocComments(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + + testCode := `package main + +// Calculate performs a calculation. +// It takes two integers and returns their sum. +func Calculate(a, b int) int { + return a + b +} +` + + if err := os.WriteFile(testFile, []byte(testCode), 0600); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + + chunker := NewChunker(chunking.DefaultChunkOptions()) + chunks, err := chunker.Chunk(context.Background(), testFile) + if err != nil { + t.Fatalf("Chunk() failed: %v", err) + } + + if len(chunks) != 1 { + t.Fatalf("Expected 1 chunk, got %d", len(chunks)) + } + + chunk := chunks[0] + if chunk.DocComment == "" { + t.Error("Expected doc comment to be present") + } + + // Doc comment should contain the comment text + expectedComment := "Calculate performs a calculation.\nIt takes two integers and returns their sum." + if chunk.DocComment != expectedComment { + t.Errorf("Expected doc comment '%s', got '%s'", expectedComment, chunk.DocComment) + } +} diff --git a/internal/chunking/manager.go b/internal/chunking/manager.go new file mode 100644 index 0000000..1115a54 --- /dev/null +++ b/internal/chunking/manager.go @@ -0,0 +1,106 @@ +package chunking + +import ( + "context" + "fmt" + "path/filepath" + "strings" +) + +// Manager dispatches files to appropriate language-specific chunkers. +type Manager struct { + chunkers map[string]Chunker // extension -> chunker + options ChunkOptions +} + +// NewManager creates a new chunking manager with the given chunkers. +func NewManager(chunkers []Chunker, options ChunkOptions) *Manager { + m := &Manager{ + chunkers: make(map[string]Chunker), + options: options, + } + + // Register chunkers by their supported extensions + for _, chunker := range chunkers { + for _, ext := range chunker.SupportedExtensions() { + m.chunkers[ext] = chunker + } + } + + return m +} + +// ChunkFile chunks a single file using the appropriate language chunker. +// Returns an error if no chunker is found for the file extension. +func (m *Manager) ChunkFile(ctx context.Context, filePath string) ([]Chunk, error) { + ext := strings.ToLower(filepath.Ext(filePath)) + chunker, ok := m.chunkers[ext] + if !ok { + return nil, fmt.Errorf("no chunker for extension %s", ext) + } + + chunks, err := chunker.Chunk(ctx, filePath) + if err != nil { + return nil, fmt.Errorf("chunk %s: %w", filePath, err) + } + + // Apply options-based filtering + filtered := make([]Chunk, 0, len(chunks)) + for _, chunk := range chunks { + // Filter by minimum lines + if m.options.MinLines > 0 { + lineCount := chunk.EndLine - chunk.StartLine + 1 + if lineCount < m.options.MinLines { + continue + } + } + + // Filter by maximum chunk size + if m.options.MaxChunkSize > 0 && len(chunk.Content) > m.options.MaxChunkSize { + // TODO: Consider splitting large chunks intelligently + // For now, skip chunks that are too large + continue + } + + filtered = append(filtered, chunk) + } + + return filtered, nil +} + +// ChunkFiles chunks multiple files in parallel. +// Returns a map of file path to chunks, and any errors encountered. +// Errors for individual files do not stop processing of other files. +func (m *Manager) ChunkFiles(ctx context.Context, filePaths []string) (map[string][]Chunk, []error) { + results := make(map[string][]Chunk) + var errors []error + + for _, filePath := range filePaths { + chunks, err := m.ChunkFile(ctx, filePath) + if err != nil { + errors = append(errors, fmt.Errorf("%s: %w", filePath, err)) + continue + } + if len(chunks) > 0 { + results[filePath] = chunks + } + } + + return results, errors +} + +// SupportsFile checks if the manager can chunk the given file based on extension. +func (m *Manager) SupportsFile(filePath string) bool { + ext := strings.ToLower(filepath.Ext(filePath)) + _, ok := m.chunkers[ext] + return ok +} + +// SupportedExtensions returns all file extensions supported by registered chunkers. +func (m *Manager) SupportedExtensions() []string { + exts := make([]string, 0, len(m.chunkers)) + for ext := range m.chunkers { + exts = append(exts, ext) + } + return exts +} diff --git a/internal/chunking/manager_test.go b/internal/chunking/manager_test.go new file mode 100644 index 0000000..6c8e867 --- /dev/null +++ b/internal/chunking/manager_test.go @@ -0,0 +1,162 @@ +package chunking + +import ( + "context" + "os" + "path/filepath" + "testing" +) + +// mockChunker is a test chunker that returns dummy chunks +type mockChunker struct{} + +func (m *mockChunker) Chunk(ctx context.Context, filePath string) ([]Chunk, error) { + // Just return an empty chunk for testing + return []Chunk{ + { + FilePath: filePath, + Language: LanguageGo, + Type: ChunkTypeFunction, + Name: "TestFunc", + StartLine: 1, + EndLine: 1, + Content: "test", + }, + }, nil +} + +func (m *mockChunker) Language() Language { + return LanguageGo +} + +func (m *mockChunker) SupportedExtensions() []string { + return []string{".go", ".py", ".ts"} +} + +func TestManager_ChunkMultipleFiles(t *testing.T) { + tmpDir := t.TempDir() + + // Create a Go file + goFile := filepath.Join(tmpDir, "test.go") + goCode := `package main + +func Hello() string { + return "hello" +} +` + if err := os.WriteFile(goFile, []byte(goCode), 0600); err != nil { + t.Fatalf("Failed to create Go file: %v", err) + } + + // Create a Python file + pyFile := filepath.Join(tmpDir, "test.py") + pyCode := `def greet(name): + return f"Hello, {name}!" + +class User: + def __init__(self, name): + self.name = name +` + if err := os.WriteFile(pyFile, []byte(pyCode), 0600); err != nil { + t.Fatalf("Failed to create Python file: %v", err) + } + + // Create a TypeScript file + tsFile := filepath.Join(tmpDir, "test.ts") + tsCode := `function add(a: number, b: number): number { + return a + b; +} + +class Calculator { + multiply(a: number, b: number): number { + return a * b; + } +} +` + if err := os.WriteFile(tsFile, []byte(tsCode), 0600); err != nil { + t.Fatalf("Failed to create TypeScript file: %v", err) + } + + // Create manager + manager := NewManager([]Chunker{&mockChunker{}}, DefaultChunkOptions()) + + // Test SupportsFile + if !manager.SupportsFile(goFile) { + t.Error("Manager should support .go files") + } + if !manager.SupportsFile(pyFile) { + t.Error("Manager should support .py files") + } + if !manager.SupportsFile(tsFile) { + t.Error("Manager should support .ts files") + } + + unsupportedFile := filepath.Join(tmpDir, "test.txt") + if manager.SupportsFile(unsupportedFile) { + t.Error("Manager should not support .txt files") + } + + // Test ChunkFiles + results, errs := manager.ChunkFiles(context.Background(), []string{goFile, pyFile, tsFile}) + if len(errs) > 0 { + t.Errorf("ChunkFiles returned errors: %v", errs) + } + + if len(results) != 3 { + t.Errorf("Expected results for 3 files, got %d", len(results)) + } + + // Verify each file has chunks + for _, file := range []string{goFile, pyFile, tsFile} { + if chunks, ok := results[file]; !ok || len(chunks) == 0 { + t.Errorf("No chunks found for file %s", file) + } + } +} + +// mockChunkerWithExts is a test chunker with configurable extensions +type mockChunkerWithExts struct { + exts []string +} + +func (m *mockChunkerWithExts) Chunk(ctx context.Context, filePath string) ([]Chunk, error) { + return nil, nil +} + +func (m *mockChunkerWithExts) Language() Language { + return LanguageGo +} + +func (m *mockChunkerWithExts) SupportedExtensions() []string { + return m.exts +} + +func TestManager_SupportedExtensions(t *testing.T) { + + // Create manager with mock chunkers + manager := NewManager([]Chunker{ + &mockChunkerWithExts{exts: []string{".go"}}, + &mockChunkerWithExts{exts: []string{".py", ".pyw"}}, + }, DefaultChunkOptions()) + + exts := manager.SupportedExtensions() + expectedExts := map[string]bool{ + ".go": false, + ".py": false, + ".pyw": false, + } + + for _, ext := range exts { + if _, ok := expectedExts[ext]; ok { + expectedExts[ext] = true + } else { + t.Errorf("Unexpected extension: %s", ext) + } + } + + for ext, found := range expectedExts { + if !found { + t.Errorf("Expected extension %s not found", ext) + } + } +} diff --git a/internal/chunking/python/chunker.go b/internal/chunking/python/chunker.go new file mode 100644 index 0000000..c4906f9 --- /dev/null +++ b/internal/chunking/python/chunker.go @@ -0,0 +1,291 @@ +// Package python provides AST-aware chunking for Python source files using tree-sitter. +package python + +import ( + "context" + "fmt" + "os" + "strings" + + sitter "github.com/smacker/go-tree-sitter" + "github.com/smacker/go-tree-sitter/python" + + "github.com/lukaszraczylo/claude-mnemonic/internal/chunking" +) + +// Chunker implements AST-aware chunking for Python files. +type Chunker struct { + parser *sitter.Parser + options chunking.ChunkOptions +} + +// NewChunker creates a new Python chunker. +func NewChunker(options chunking.ChunkOptions) *Chunker { + parser := sitter.NewParser() + parser.SetLanguage(python.GetLanguage()) + + return &Chunker{ + options: options, + parser: parser, + } +} + +// Language returns the language this chunker supports. +func (c *Chunker) Language() chunking.Language { + return chunking.LanguagePython +} + +// SupportedExtensions returns the file extensions this chunker handles. +func (c *Chunker) SupportedExtensions() []string { + return []string{".py"} +} + +// Chunk parses a Python source file and returns semantic code chunks. +func (c *Chunker) Chunk(ctx context.Context, filePath string) ([]chunking.Chunk, error) { + // Read file content + content, err := os.ReadFile(filePath) + if err != nil { + return nil, fmt.Errorf("read file: %w", err) + } + + // Parse the Python file + tree, err := c.parser.ParseCtx(ctx, nil, content) + if err != nil { + return nil, fmt.Errorf("parse Python file: %w", err) + } + defer tree.Close() + + sourceLines := strings.Split(string(content), "\n") + chunks := make([]chunking.Chunk, 0) + + // Walk the AST and extract chunks + c.walkNode(tree.RootNode(), content, sourceLines, filePath, "", &chunks) + + return chunks, nil +} + +// walkNode recursively walks the tree-sitter AST and extracts chunks. +func (c *Chunker) walkNode(node *sitter.Node, source []byte, sourceLines []string, filePath string, parentName string, chunks *[]chunking.Chunk) { + nodeType := node.Type() + + switch nodeType { + case "function_definition": + chunk := c.extractFunction(node, source, sourceLines, filePath, parentName) + if chunk != nil { + *chunks = append(*chunks, *chunk) + } + + case "class_definition": + chunk := c.extractClass(node, source, sourceLines, filePath) + if chunk != nil { + *chunks = append(*chunks, *chunk) + + // Walk class body to find methods + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child.Type() == "block" { + c.walkNode(child, source, sourceLines, filePath, chunk.Name, chunks) + } + } + } + return // Don't walk children again + + case "block": + // Walk statements in block + for i := 0; i < int(node.ChildCount()); i++ { + c.walkNode(node.Child(i), source, sourceLines, filePath, parentName, chunks) + } + return + } + + // Walk all children + for i := 0; i < int(node.ChildCount()); i++ { + c.walkNode(node.Child(i), source, sourceLines, filePath, parentName, chunks) + } +} + +// extractFunction extracts a function definition chunk. +func (c *Chunker) extractFunction(node *sitter.Node, source []byte, sourceLines []string, filePath string, parentName string) *chunking.Chunk { + // Find function name + var nameNode *sitter.Node + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child.Type() == "identifier" { + nameNode = child + break + } + } + + if nameNode == nil { + return nil + } + + name := nameNode.Content(source) + + // Skip private functions if configured + if !c.options.IncludePrivate && strings.HasPrefix(name, "_") && !strings.HasPrefix(name, "__") { + return nil + } + + startLine := int(node.StartPoint().Row) + 1 + endLine := int(node.EndPoint().Row) + 1 + + chunk := &chunking.Chunk{ + FilePath: filePath, + Language: chunking.LanguagePython, + Name: name, + ParentName: parentName, + StartLine: startLine, + EndLine: endLine, + Content: c.extractLines(sourceLines, startLine, endLine), + } + + // Determine if this is a method or function + if parentName != "" { + chunk.Type = chunking.ChunkTypeMethod + } else { + chunk.Type = chunking.ChunkTypeFunction + } + + // Extract signature (def line) + chunk.Signature = c.extractFunctionSignature(node, source, sourceLines) + + // Extract docstring as doc comment + if c.options.IncludeDocComments { + chunk.DocComment = c.extractDocstring(node, source) + } + + return chunk +} + +// extractClass extracts a class definition chunk. +func (c *Chunker) extractClass(node *sitter.Node, source []byte, sourceLines []string, filePath string) *chunking.Chunk { + // Find class name + var nameNode *sitter.Node + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child.Type() == "identifier" { + nameNode = child + break + } + } + + if nameNode == nil { + return nil + } + + name := nameNode.Content(source) + + // Skip private classes if configured + if !c.options.IncludePrivate && strings.HasPrefix(name, "_") && !strings.HasPrefix(name, "__") { + return nil + } + + startLine := int(node.StartPoint().Row) + 1 + endLine := int(node.EndPoint().Row) + 1 + + chunk := &chunking.Chunk{ + FilePath: filePath, + Language: chunking.LanguagePython, + Type: chunking.ChunkTypeClass, + Name: name, + StartLine: startLine, + EndLine: endLine, + Content: c.extractLines(sourceLines, startLine, endLine), + } + + // Extract class signature (class line) + chunk.Signature = c.extractClassSignature(node, source, sourceLines) + + // Extract docstring as doc comment + if c.options.IncludeDocComments { + chunk.DocComment = c.extractDocstring(node, source) + } + + return chunk +} + +// extractFunctionSignature extracts the function definition line. +func (c *Chunker) extractFunctionSignature(node *sitter.Node, source []byte, sourceLines []string) string { + startLine := int(node.StartPoint().Row) + 1 + + // Find the colon that ends the signature + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child.Type() == ":" { + endLine := int(child.EndPoint().Row) + 1 + return strings.TrimSpace(c.extractLines(sourceLines, startLine, endLine)) + } + } + + // Fallback: just return first line + return strings.TrimSpace(c.extractLines(sourceLines, startLine, startLine)) +} + +// extractClassSignature extracts the class definition line. +func (c *Chunker) extractClassSignature(node *sitter.Node, source []byte, sourceLines []string) string { + startLine := int(node.StartPoint().Row) + 1 + + // Find the colon that ends the signature + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child.Type() == ":" { + endLine := int(child.EndPoint().Row) + 1 + return strings.TrimSpace(c.extractLines(sourceLines, startLine, endLine)) + } + } + + // Fallback: just return first line + return strings.TrimSpace(c.extractLines(sourceLines, startLine, startLine)) +} + +// extractDocstring extracts the docstring from a function or class. +func (c *Chunker) extractDocstring(node *sitter.Node, source []byte) string { + // Find the block + var blockNode *sitter.Node + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child.Type() == "block" { + blockNode = child + break + } + } + + if blockNode == nil { + return "" + } + + // Check if first statement in block is a string (docstring) + for i := 0; i < int(blockNode.ChildCount()); i++ { + child := blockNode.Child(i) + if child.Type() == "expression_statement" { + // Check if it contains a string + for j := 0; j < int(child.ChildCount()); j++ { + grandchild := child.Child(j) + if grandchild.Type() == "string" { + docstring := grandchild.Content(source) + // Remove quotes + docstring = strings.Trim(docstring, `"'`) + return strings.TrimSpace(docstring) + } + } + } + } + + return "" +} + +// extractLines extracts a range of lines from source (1-indexed, inclusive). +func (c *Chunker) extractLines(lines []string, start, end int) string { + if start < 1 || end < start || start > len(lines) { + return "" + } + + startIdx := start - 1 + endIdx := end + if endIdx > len(lines) { + endIdx = len(lines) + } + + return strings.Join(lines[startIdx:endIdx], "\n") +} diff --git a/internal/chunking/types.go b/internal/chunking/types.go new file mode 100644 index 0000000..ce639b1 --- /dev/null +++ b/internal/chunking/types.go @@ -0,0 +1,140 @@ +// Package chunking provides AST-aware code chunking for semantic code search. +// Chunks code files into logical units (functions, classes, methods) that preserve +// semantic boundaries for better vector embedding and retrieval. +package chunking + +import ( + "context" + "fmt" + "strings" +) + +// ChunkType represents the type of code chunk. +type ChunkType string + +const ( + // ChunkTypeFunction represents a standalone function. + ChunkTypeFunction ChunkType = "function" + // ChunkTypeMethod represents a method on a class/struct/type. + ChunkTypeMethod ChunkType = "method" + // ChunkTypeClass represents a class or struct definition. + ChunkTypeClass ChunkType = "class" + // ChunkTypeInterface represents an interface definition. + ChunkTypeInterface ChunkType = "interface" + // ChunkTypeType represents a type alias or type definition. + ChunkTypeType ChunkType = "type" + // ChunkTypeConst represents constant declarations. + ChunkTypeConst ChunkType = "const" + // ChunkTypeVar represents variable declarations. + ChunkTypeVar ChunkType = "var" +) + +// Language represents a programming language. +type Language string + +const ( + // LanguageGo represents the Go programming language. + LanguageGo Language = "go" + // LanguagePython represents the Python programming language. + LanguagePython Language = "python" + // LanguageTypeScript represents the TypeScript programming language. + LanguageTypeScript Language = "typescript" + // LanguageJavaScript represents the JavaScript programming language. + LanguageJavaScript Language = "javascript" +) + +// Chunk represents a semantic code chunk with AST-derived boundaries. +type Chunk struct { + Metadata map[string]interface{} + FilePath string + Language Language + Type ChunkType + Name string + ParentName string + Content string + Signature string + DocComment string + StartLine int + EndLine int +} + +// Identifier returns a human-readable identifier for this chunk. +// Format: "ParentName.Name" for methods, "Name" for top-level. +func (c *Chunk) Identifier() string { + if c.ParentName != "" { + return fmt.Sprintf("%s.%s", c.ParentName, c.Name) + } + return c.Name +} + +// LineRange returns a human-readable line range. +// Format: "L123-L456" +func (c *Chunk) LineRange() string { + return fmt.Sprintf("L%d-L%d", c.StartLine, c.EndLine) +} + +// SearchableContent returns content optimized for semantic search. +// Combines signature, doc comment, and content in a structured format. +func (c *Chunk) SearchableContent() string { + var parts []string + + // Include signature for functions/methods + if c.Signature != "" { + parts = append(parts, c.Signature) + } + + // Include doc comment + if c.DocComment != "" { + parts = append(parts, c.DocComment) + } + + // Include actual content + if c.Content != "" { + parts = append(parts, c.Content) + } + + return strings.Join(parts, "\n\n") +} + +// Chunker is the interface for language-specific code chunkers. +type Chunker interface { + // Chunk parses a source file and returns semantic code chunks. + // Returns an error if the file cannot be parsed or read. + Chunk(ctx context.Context, filePath string) ([]Chunk, error) + + // Language returns the language this chunker supports. + Language() Language + + // SupportedExtensions returns file extensions this chunker handles. + // Example: []string{".go"} for Go chunker + SupportedExtensions() []string +} + +// ChunkOptions provides options for chunking behavior. +type ChunkOptions struct { + // MaxChunkSize is the maximum size of a chunk in bytes. + // Chunks larger than this will be split (respecting boundaries where possible). + // 0 means no limit. + MaxChunkSize int + + // IncludeDocComments controls whether to include documentation comments. + IncludeDocComments bool + + // IncludePrivate controls whether to include private/unexported symbols. + IncludePrivate bool + + // MinLines is the minimum number of lines for a chunk to be included. + // Chunks smaller than this will be skipped. + // 0 means no minimum. + MinLines int +} + +// DefaultChunkOptions returns sensible default options. +func DefaultChunkOptions() ChunkOptions { + return ChunkOptions{ + MaxChunkSize: 8192, // ~8KB per chunk (well under token limit) + IncludeDocComments: true, + IncludePrivate: true, // Include all symbols for comprehensive search + MinLines: 0, // No minimum - include even single-line functions + } +} diff --git a/internal/chunking/typescript/chunker.go b/internal/chunking/typescript/chunker.go new file mode 100644 index 0000000..44029cd --- /dev/null +++ b/internal/chunking/typescript/chunker.go @@ -0,0 +1,403 @@ +// Package typescript provides AST-aware chunking for TypeScript and JavaScript source files using tree-sitter. +package typescript + +import ( + "context" + "fmt" + "os" + "strings" + + sitter "github.com/smacker/go-tree-sitter" + "github.com/smacker/go-tree-sitter/typescript/typescript" + + "github.com/lukaszraczylo/claude-mnemonic/internal/chunking" +) + +// Chunker implements AST-aware chunking for TypeScript/JavaScript files. +type Chunker struct { + parser *sitter.Parser + options chunking.ChunkOptions +} + +// NewChunker creates a new TypeScript chunker. +func NewChunker(options chunking.ChunkOptions) *Chunker { + parser := sitter.NewParser() + parser.SetLanguage(typescript.GetLanguage()) + + return &Chunker{ + options: options, + parser: parser, + } +} + +// Language returns the language this chunker supports. +func (c *Chunker) Language() chunking.Language { + return chunking.LanguageTypeScript +} + +// SupportedExtensions returns the file extensions this chunker handles. +func (c *Chunker) SupportedExtensions() []string { + return []string{".ts", ".tsx", ".js", ".jsx"} +} + +// Chunk parses a TypeScript/JavaScript source file and returns semantic code chunks. +func (c *Chunker) Chunk(ctx context.Context, filePath string) ([]chunking.Chunk, error) { + // Read file content + content, err := os.ReadFile(filePath) + if err != nil { + return nil, fmt.Errorf("read file: %w", err) + } + + // Parse the file + tree, err := c.parser.ParseCtx(ctx, nil, content) + if err != nil { + return nil, fmt.Errorf("parse TypeScript file: %w", err) + } + defer tree.Close() + + sourceLines := strings.Split(string(content), "\n") + chunks := make([]chunking.Chunk, 0) + + // Walk the AST and extract chunks + c.walkNode(tree.RootNode(), content, sourceLines, filePath, "", &chunks) + + return chunks, nil +} + +// walkNode recursively walks the tree-sitter AST and extracts chunks. +func (c *Chunker) walkNode(node *sitter.Node, source []byte, sourceLines []string, filePath string, parentName string, chunks *[]chunking.Chunk) { + nodeType := node.Type() + + switch nodeType { + case "function_declaration": + chunk := c.extractFunction(node, source, sourceLines, filePath, parentName) + if chunk != nil { + *chunks = append(*chunks, *chunk) + } + + case "method_definition": + chunk := c.extractMethod(node, source, sourceLines, filePath, parentName) + if chunk != nil { + *chunks = append(*chunks, *chunk) + } + + case "arrow_function", "function_expression": + // Handle arrow functions and function expressions assigned to variables + chunk := c.extractFunctionExpression(node, source, sourceLines, filePath, parentName) + if chunk != nil { + *chunks = append(*chunks, *chunk) + } + + case "class_declaration": + chunk := c.extractClass(node, source, sourceLines, filePath) + if chunk != nil { + *chunks = append(*chunks, *chunk) + + // Walk class body to find methods + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child.Type() == "class_body" { + c.walkNode(child, source, sourceLines, filePath, chunk.Name, chunks) + } + } + } + return // Don't walk children again + + case "interface_declaration": + chunk := c.extractInterface(node, source, sourceLines, filePath) + if chunk != nil { + *chunks = append(*chunks, *chunk) + } + + case "type_alias_declaration": + chunk := c.extractTypeAlias(node, source, sourceLines, filePath) + if chunk != nil { + *chunks = append(*chunks, *chunk) + } + } + + // Walk all children + for i := 0; i < int(node.ChildCount()); i++ { + c.walkNode(node.Child(i), source, sourceLines, filePath, parentName, chunks) + } +} + +// extractFunction extracts a function declaration. +func (c *Chunker) extractFunction(node *sitter.Node, source []byte, sourceLines []string, filePath string, parentName string) *chunking.Chunk { + name := c.findChildContent(node, "identifier", source) + if name == "" { + return nil + } + + startLine := int(node.StartPoint().Row) + 1 + endLine := int(node.EndPoint().Row) + 1 + + chunk := &chunking.Chunk{ + FilePath: filePath, + Language: chunking.LanguageTypeScript, + Type: chunking.ChunkTypeFunction, + Name: name, + ParentName: parentName, + StartLine: startLine, + EndLine: endLine, + Content: c.extractLines(sourceLines, startLine, endLine), + Signature: c.extractFunctionSignature(node, source, sourceLines), + } + + // Extract JSDoc comment + if c.options.IncludeDocComments { + chunk.DocComment = c.extractComment(node, source) + } + + return chunk +} + +// extractMethod extracts a method definition from a class. +func (c *Chunker) extractMethod(node *sitter.Node, source []byte, sourceLines []string, filePath string, parentName string) *chunking.Chunk { + name := c.findChildContent(node, "property_identifier", source) + if name == "" { + return nil + } + + // Skip private methods if configured + if !c.options.IncludePrivate && strings.HasPrefix(name, "_") { + return nil + } + + startLine := int(node.StartPoint().Row) + 1 + endLine := int(node.EndPoint().Row) + 1 + + chunk := &chunking.Chunk{ + FilePath: filePath, + Language: chunking.LanguageTypeScript, + Type: chunking.ChunkTypeMethod, + Name: name, + ParentName: parentName, + StartLine: startLine, + EndLine: endLine, + Content: c.extractLines(sourceLines, startLine, endLine), + Signature: c.extractMethodSignature(node, source, sourceLines), + } + + // Extract JSDoc comment + if c.options.IncludeDocComments { + chunk.DocComment = c.extractComment(node, source) + } + + return chunk +} + +// extractFunctionExpression extracts arrow functions and function expressions. +func (c *Chunker) extractFunctionExpression(node *sitter.Node, source []byte, sourceLines []string, filePath string, parentName string) *chunking.Chunk { + // Try to find the variable name from parent + parent := node.Parent() + if parent == nil { + return nil + } + + var name string + if parent.Type() == "variable_declarator" { + name = c.findChildContent(parent, "identifier", source) + } else if parent.Type() == "assignment_expression" { + // Handle const foo = () => {} + for i := 0; i < int(parent.ChildCount()); i++ { + child := parent.Child(i) + if child.Type() == "identifier" || child.Type() == "member_expression" { + name = child.Content(source) + break + } + } + } + + if name == "" { + return nil // Anonymous function, skip + } + + startLine := int(node.StartPoint().Row) + 1 + endLine := int(node.EndPoint().Row) + 1 + + chunk := &chunking.Chunk{ + FilePath: filePath, + Language: chunking.LanguageTypeScript, + Type: chunking.ChunkTypeFunction, + Name: name, + ParentName: parentName, + StartLine: startLine, + EndLine: endLine, + Content: c.extractLines(sourceLines, startLine, endLine), + } + + return chunk +} + +// extractClass extracts a class declaration. +func (c *Chunker) extractClass(node *sitter.Node, source []byte, sourceLines []string, filePath string) *chunking.Chunk { + name := c.findChildContent(node, "type_identifier", source) + if name == "" { + return nil + } + + startLine := int(node.StartPoint().Row) + 1 + endLine := int(node.EndPoint().Row) + 1 + + chunk := &chunking.Chunk{ + FilePath: filePath, + Language: chunking.LanguageTypeScript, + Type: chunking.ChunkTypeClass, + Name: name, + StartLine: startLine, + EndLine: endLine, + Content: c.extractLines(sourceLines, startLine, endLine), + Signature: c.extractClassSignature(node, source, sourceLines), + } + + // Extract JSDoc comment + if c.options.IncludeDocComments { + chunk.DocComment = c.extractComment(node, source) + } + + return chunk +} + +// extractInterface extracts an interface declaration. +func (c *Chunker) extractInterface(node *sitter.Node, source []byte, sourceLines []string, filePath string) *chunking.Chunk { + name := c.findChildContent(node, "type_identifier", source) + if name == "" { + return nil + } + + startLine := int(node.StartPoint().Row) + 1 + endLine := int(node.EndPoint().Row) + 1 + + chunk := &chunking.Chunk{ + FilePath: filePath, + Language: chunking.LanguageTypeScript, + Type: chunking.ChunkTypeInterface, + Name: name, + StartLine: startLine, + EndLine: endLine, + Content: c.extractLines(sourceLines, startLine, endLine), + } + + // Extract JSDoc comment + if c.options.IncludeDocComments { + chunk.DocComment = c.extractComment(node, source) + } + + return chunk +} + +// extractTypeAlias extracts a type alias declaration. +func (c *Chunker) extractTypeAlias(node *sitter.Node, source []byte, sourceLines []string, filePath string) *chunking.Chunk { + name := c.findChildContent(node, "type_identifier", source) + if name == "" { + return nil + } + + startLine := int(node.StartPoint().Row) + 1 + endLine := int(node.EndPoint().Row) + 1 + + chunk := &chunking.Chunk{ + FilePath: filePath, + Language: chunking.LanguageTypeScript, + Type: chunking.ChunkTypeType, + Name: name, + StartLine: startLine, + EndLine: endLine, + Content: c.extractLines(sourceLines, startLine, endLine), + } + + return chunk +} + +// findChildContent finds the first child of the given type and returns its content. +func (c *Chunker) findChildContent(node *sitter.Node, childType string, source []byte) string { + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child.Type() == childType { + return child.Content(source) + } + } + return "" +} + +// extractFunctionSignature extracts the function signature. +func (c *Chunker) extractFunctionSignature(node *sitter.Node, source []byte, sourceLines []string) string { + startLine := int(node.StartPoint().Row) + 1 + + // Find the opening brace of the body + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child.Type() == "statement_block" { + endLine := int(child.StartPoint().Row) + 1 + return strings.TrimSpace(c.extractLines(sourceLines, startLine, endLine-1)) + } + } + + // Fallback: just return first line + return strings.TrimSpace(c.extractLines(sourceLines, startLine, startLine)) +} + +// extractMethodSignature extracts the method signature. +func (c *Chunker) extractMethodSignature(node *sitter.Node, source []byte, sourceLines []string) string { + startLine := int(node.StartPoint().Row) + 1 + + // Find the opening brace of the body + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child.Type() == "statement_block" { + endLine := int(child.StartPoint().Row) + 1 + return strings.TrimSpace(c.extractLines(sourceLines, startLine, endLine-1)) + } + } + + return strings.TrimSpace(c.extractLines(sourceLines, startLine, startLine)) +} + +// extractClassSignature extracts the class declaration line. +func (c *Chunker) extractClassSignature(node *sitter.Node, source []byte, sourceLines []string) string { + startLine := int(node.StartPoint().Row) + 1 + + // Find the opening brace of the class body + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child.Type() == "class_body" { + endLine := int(child.StartPoint().Row) + 1 + return strings.TrimSpace(c.extractLines(sourceLines, startLine, endLine-1)) + } + } + + return strings.TrimSpace(c.extractLines(sourceLines, startLine, startLine)) +} + +// extractComment extracts JSDoc or other comments from a node. +func (c *Chunker) extractComment(node *sitter.Node, source []byte) string { + // Check previous sibling for comment + prevSibling := node.PrevSibling() + if prevSibling != nil && prevSibling.Type() == "comment" { + comment := prevSibling.Content(source) + // Remove comment markers + comment = strings.TrimPrefix(comment, "/**") + comment = strings.TrimPrefix(comment, "/*") + comment = strings.TrimSuffix(comment, "*/") + comment = strings.TrimPrefix(comment, "//") + return strings.TrimSpace(comment) + } + + return "" +} + +// extractLines extracts a range of lines from source (1-indexed, inclusive). +func (c *Chunker) extractLines(lines []string, start, end int) string { + if start < 1 || end < start || start > len(lines) { + return "" + } + + startIdx := start - 1 + endIdx := end + if endIdx > len(lines) { + endIdx = len(lines) + } + + return strings.Join(lines[startIdx:endIdx], "\n") +} diff --git a/internal/config/config.go b/internal/config/config.go index 13f952e..c07686b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -37,40 +37,29 @@ var CriticalConcepts = []string{ // Config holds the application configuration. type Config struct { - // Worker settings - WorkerPort int `json:"worker_port"` - - // Database settings - DBPath string `json:"db_path"` - MaxConns int `json:"max_conns"` - - // SDK Agent settings - Model string `json:"model"` - ClaudeCodePath string `json:"claude_code_path"` - - // Embedding settings - EmbeddingModel string `json:"embedding_model"` // e.g., "bge-v1.5" - - // Reranking settings (cross-encoder) - RerankingEnabled bool `json:"reranking_enabled"` // Enable cross-encoder reranking - RerankingCandidates int `json:"reranking_candidates"` // Number of candidates to retrieve before reranking (default 100) - RerankingResults int `json:"reranking_results"` // Number of results to return after reranking (default 10) - RerankingAlpha float64 `json:"reranking_alpha"` // Weight for combining scores: alpha*rerank + (1-alpha)*original (default 0.7) - RerankingMinImprovement float64 `json:"reranking_min_improvement"` // Minimum rank improvement to trigger reranking (default 0, always rerank) - RerankingPureMode bool `json:"reranking_pure_mode"` // Use pure cross-encoder scores without combining with bi-encoder (default false) - - // Context injection settings + ContextFullField string `json:"context_full_field"` + DBPath string `json:"db_path"` + Model string `json:"model"` + ClaudeCodePath string `json:"claude_code_path"` + EmbeddingModel string `json:"embedding_model"` + ContextObsConcepts []string `json:"context_obs_concepts"` + ContextObsTypes []string `json:"context_obs_types"` + RerankingMinImprovement float64 `json:"reranking_min_improvement"` + RerankingCandidates int `json:"reranking_candidates"` + RerankingAlpha float64 `json:"reranking_alpha"` + WorkerPort int `json:"worker_port"` + ContextMaxPromptResults int `json:"context_max_prompt_results"` ContextObservations int `json:"context_observations"` ContextFullCount int `json:"context_full_count"` ContextSessionCount int `json:"context_session_count"` - ContextShowReadTokens bool `json:"context_show_read_tokens"` - ContextShowWorkTokens bool `json:"context_show_work_tokens"` - ContextFullField string `json:"context_full_field"` + ContextRelevanceThreshold float64 `json:"context_relevance_threshold"` + MaxConns int `json:"max_conns"` + RerankingResults int `json:"reranking_results"` ContextShowLastSummary bool `json:"context_show_last_summary"` - ContextObsTypes []string `json:"context_obs_types"` - ContextObsConcepts []string `json:"context_obs_concepts"` - ContextRelevanceThreshold float64 `json:"context_relevance_threshold"` // 0.0-1.0, minimum similarity for inclusion - ContextMaxPromptResults int `json:"context_max_prompt_results"` // Max results per prompt (0 = threshold only) + RerankingEnabled bool `json:"reranking_enabled"` + ContextShowWorkTokens bool `json:"context_show_work_tokens"` + ContextShowReadTokens bool `json:"context_show_read_tokens"` + RerankingPureMode bool `json:"reranking_pure_mode"` } var ( diff --git a/internal/config/config_test.go b/internal/config/config_test.go index bc8bc4e..02dfd0c 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -121,8 +121,8 @@ func (s *ConfigSuite) TestLoad_TableDriven() { tests := []struct { name string settingsJSON string - expectedPort int expectedModel string + expectedPort int expectedObsObs int }{ { @@ -183,12 +183,12 @@ func (s *ConfigSuite) TestLoad_TableDriven() { s.Require().NoError(err) if tt.settingsJSON != "" { - err := os.WriteFile( + writeErr := os.WriteFile( filepath.Join(tempDir, ".claude-mnemonic", "settings.json"), []byte(tt.settingsJSON), 0600, ) - s.Require().NoError(err) + s.Require().NoError(writeErr) } cfg, err := Load() diff --git a/internal/db/sqlite/conflict.go b/internal/db/sqlite/conflict.go index 4cd6673..959b7c9 100644 --- a/internal/db/sqlite/conflict.go +++ b/internal/db/sqlite/conflict.go @@ -182,13 +182,13 @@ func (s *ConflictStore) CleanupSupersededObservations(ctx context.Context, proje var toDelete []int64 for rows.Next() { var id int64 - if err := rows.Scan(&id); err != nil { - return nil, err + if scanErr := rows.Scan(&id); scanErr != nil { + return nil, scanErr } toDelete = append(toDelete, id) } - if err := rows.Err(); err != nil { - return nil, err + if rowsErr := rows.Err(); rowsErr != nil { + return nil, rowsErr } if len(toDelete) == 0 { @@ -197,8 +197,8 @@ func (s *ConflictStore) CleanupSupersededObservations(ctx context.Context, proje // Delete the conflict records first (due to foreign key constraints) for _, obsID := range toDelete { - if err := s.DeleteConflictsByObservationID(ctx, obsID); err != nil { - return nil, err + if delErr := s.DeleteConflictsByObservationID(ctx, obsID); delErr != nil { + return nil, delErr } } diff --git a/internal/db/sqlite/helpers_test.go b/internal/db/sqlite/helpers_test.go index da7e99f..1d6433a 100644 --- a/internal/db/sqlite/helpers_test.go +++ b/internal/db/sqlite/helpers_test.go @@ -54,14 +54,14 @@ func TestNullInt(t *testing.T) { func TestRepeatPlaceholders(t *testing.T) { tests := []struct { name string - n int expected string + n int }{ - {"zero", 0, ""}, - {"negative", -1, ""}, - {"one", 1, ", ?"}, - {"two", 2, ", ?, ?"}, - {"three", 3, ", ?, ?, ?"}, + {name: "zero", n: 0, expected: ""}, + {name: "negative", n: -1, expected: ""}, + {name: "one", n: 1, expected: ", ?"}, + {name: "two", n: 2, expected: ", ?, ?"}, + {name: "three", n: 3, expected: ", ?, ?, ?"}, } for _, tt := range tests { diff --git a/internal/db/sqlite/migrations.go b/internal/db/sqlite/migrations.go index 0bb3bc6..b2d6b55 100644 --- a/internal/db/sqlite/migrations.go +++ b/internal/db/sqlite/migrations.go @@ -9,9 +9,9 @@ import ( // Migration represents a database schema migration. type Migration struct { - Version int Name string SQL string + Version int } // Migrations is the list of all database migrations in order. @@ -539,11 +539,11 @@ func (m *MigrationManager) ApplyMigration(migration Migration) error { if err != nil { return fmt.Errorf("begin transaction: %w", err) } - defer tx.Rollback() + defer func() { _ = tx.Rollback() }() // Execute migration SQL - if _, err := tx.Exec(migration.SQL); err != nil { - return fmt.Errorf("execute migration %d (%s): %w", migration.Version, migration.Name, err) + if _, execErr := tx.Exec(migration.SQL); execErr != nil { + return fmt.Errorf("execute migration %d (%s): %w", migration.Version, migration.Name, execErr) } // Record migration diff --git a/internal/db/sqlite/observation.go b/internal/db/sqlite/observation.go index 78f56ef..d7874c3 100644 --- a/internal/db/sqlite/observation.go +++ b/internal/db/sqlite/observation.go @@ -585,13 +585,13 @@ func (s *ObservationStore) CleanupOldObservations(ctx context.Context, project s var toDelete []int64 for rows.Next() { var id int64 - if err := rows.Scan(&id); err != nil { - return nil, err + if scanErr := rows.Scan(&id); scanErr != nil { + return nil, scanErr } toDelete = append(toDelete, id) } - if err := rows.Err(); err != nil { - return nil, err + if rowsErr := rows.Err(); rowsErr != nil { + return nil, rowsErr } if len(toDelete) == 0 { diff --git a/internal/db/sqlite/observation_test.go b/internal/db/sqlite/observation_test.go index 6b4972c..3d6a6ca 100644 --- a/internal/db/sqlite/observation_test.go +++ b/internal/db/sqlite/observation_test.go @@ -65,10 +65,10 @@ func (s *ObservationStoreSuite) TestStoreObservation_TableDriven() { ctx := context.Background() tests := []struct { + obs *models.ParsedObservation name string sdkSessionID string project string - obs *models.ParsedObservation promptNum int tokens int64 wantErr bool @@ -308,8 +308,8 @@ func (s *ObservationStoreSuite) TestGetObservationsByIDs() { tests := []struct { name string - queryIDs []int64 orderBy string + queryIDs []int64 limit int wantCount int }{ diff --git a/internal/db/sqlite/prompt.go b/internal/db/sqlite/prompt.go index 2d44661..faa65a8 100644 --- a/internal/db/sqlite/prompt.go +++ b/internal/db/sqlite/prompt.go @@ -102,13 +102,13 @@ func (s *PromptStore) CleanupOldPrompts(ctx context.Context) ([]int64, error) { var toDelete []int64 for rows.Next() { var id int64 - if err := rows.Scan(&id); err != nil { - return nil, err + if scanErr := rows.Scan(&id); scanErr != nil { + return nil, scanErr } toDelete = append(toDelete, id) } - if err := rows.Err(); err != nil { - return nil, err + if rowsErr := rows.Err(); rowsErr != nil { + return nil, rowsErr } if len(toDelete) == 0 { diff --git a/internal/db/sqlite/prompt_test.go b/internal/db/sqlite/prompt_test.go index 6d16093..1237a70 100644 --- a/internal/db/sqlite/prompt_test.go +++ b/internal/db/sqlite/prompt_test.go @@ -261,14 +261,14 @@ func TestPromptStore_SaveMultiplePrompts(t *testing.T) { tests := []struct { claudeSessionID string - promptNum int text string + promptNum int matches int }{ - {"claude-1", 1, "First prompt", 5}, - {"claude-1", 2, "Second prompt", 3}, - {"claude-2", 1, "Third prompt", 0}, - {"claude-1", 3, "Fourth prompt", 10}, + {claudeSessionID: "claude-1", promptNum: 1, text: "First prompt", matches: 5}, + {claudeSessionID: "claude-1", promptNum: 2, text: "Second prompt", matches: 3}, + {claudeSessionID: "claude-2", promptNum: 1, text: "Third prompt", matches: 0}, + {claudeSessionID: "claude-1", promptNum: 3, text: "Fourth prompt", matches: 10}, } for _, tt := range tests { diff --git a/internal/db/sqlite/scoring.go b/internal/db/sqlite/scoring.go index 065f1c0..a594ca1 100644 --- a/internal/db/sqlite/scoring.go +++ b/internal/db/sqlite/scoring.go @@ -71,7 +71,7 @@ func (s *ObservationStore) UpdateImportanceScores(ctx context.Context, scores ma if err != nil { return err } - defer tx.Rollback() + defer func() { _ = tx.Rollback() }() now := time.Now().UnixMilli() stmt, err := tx.PrepareContext(ctx, ` @@ -171,7 +171,7 @@ func (s *ObservationStore) UpdateConceptWeights(ctx context.Context, weights map if err != nil { return err } - defer tx.Rollback() + defer func() { _ = tx.Rollback() }() stmt, err := tx.PrepareContext(ctx, ` INSERT INTO concept_weights (concept, weight, updated_at) diff --git a/internal/db/sqlite/session_test.go b/internal/db/sqlite/session_test.go index b445ed1..d286882 100644 --- a/internal/db/sqlite/session_test.go +++ b/internal/db/sqlite/session_test.go @@ -211,8 +211,8 @@ func (s *SessionStoreSuite) TestFindAnySDKSession_Scenarios() { tests := []struct { name string claudeSessionID string - wantFound bool wantProject string + wantFound bool }{ { name: "find existing session 1", diff --git a/internal/db/sqlite/store.go b/internal/db/sqlite/store.go index 3f9aa51..e6ef989 100644 --- a/internal/db/sqlite/store.go +++ b/internal/db/sqlite/store.go @@ -94,8 +94,8 @@ func (s *Store) GetStmt(query string) (*sql.Stmt, error) { defer s.stmtMu.Unlock() // Double-check after acquiring write lock - if stmt, ok := s.stmtCache[query]; ok { - return stmt, nil + if cachedStmt, ok := s.stmtCache[query]; ok { + return cachedStmt, nil } stmt, err := s.db.Prepare(query) diff --git a/internal/db/sqlite/store_test.go b/internal/db/sqlite/store_test.go index 8f96aa0..c73ed6b 100644 --- a/internal/db/sqlite/store_test.go +++ b/internal/db/sqlite/store_test.go @@ -130,13 +130,13 @@ func (s *StoreSuite) TestQueryContext() { seedSession(s.T(), s.db, "claude-1", "sdk-1", "project-a") tests := []struct { + setupFunc func() + assertFunc func(rows *sql.Rows) name string query string args []interface{} - wantErr bool wantRows int - setupFunc func() - assertFunc func(rows *sql.Rows) + wantErr bool }{ { name: "query existing session", @@ -366,8 +366,8 @@ func (s *HelpersSuite) TestNullInt() { func (s *HelpersSuite) TestRepeatPlaceholders() { tests := []struct { name string - input int expected string + input int }{ { name: "zero", @@ -438,10 +438,10 @@ func TestBuildGetByIDsQuery(t *testing.T) { tests := []struct { name string baseQuery string - ids []int64 orderBy string - limit int wantQuery string + ids []int64 + limit int wantArgs int }{ { @@ -483,10 +483,10 @@ func TestEnsureSessionExists(t *testing.T) { ctx := context.Background() tests := []struct { + setup func() name string sdkSessionID string project string - setup func() wantErr bool }{ { diff --git a/internal/embedding/model.go b/internal/embedding/model.go index 3a40838..2c1124d 100644 --- a/internal/embedding/model.go +++ b/internal/embedding/model.go @@ -21,15 +21,10 @@ const ( // ONNXConfig describes ONNX-specific model configuration. // This allows different models to specify their tensor names and pooling needs. type ONNXConfig struct { - // InputNames are the ONNX input tensor names in order. - InputNames []string - // OutputNames are the ONNX output tensor names. + Pooling PoolingStrategy + InputNames []string OutputNames []string - // Pooling specifies how to convert token embeddings to sentence embeddings. - // If PoolingNone, the model outputs sentence embeddings directly. - Pooling PoolingStrategy - // HiddenSize is the embedding dimension (used for pooling calculations). - HiddenSize int + HiddenSize int } // EmbeddingModel represents a text embedding model. @@ -62,11 +57,11 @@ type ONNXConfigurer interface { // ModelMetadata describes an embedding model for UI/config. type ModelMetadata struct { - Name string `json:"name"` // Human-readable name - Version string `json:"version"` // Short ID for DB storage - Dimensions int `json:"dimensions"` // Vector size - Description string `json:"description"` // Brief description - Default bool `json:"default"` // Is this the default model? + Name string `json:"name"` + Version string `json:"version"` + Description string `json:"description"` + Dimensions int `json:"dimensions"` + Default bool `json:"default"` } // ModelFactory creates a new instance of an embedding model. @@ -74,10 +69,10 @@ type ModelFactory func() (EmbeddingModel, error) // ModelRegistry provides model lookup by version. type ModelRegistry struct { - mu sync.RWMutex models map[string]ModelFactory metadata map[string]ModelMetadata defaultModel string + mu sync.RWMutex } // NewModelRegistry creates a new model registry. diff --git a/internal/embedding/service.go b/internal/embedding/service.go index 6611322..048e76f 100644 --- a/internal/embedding/service.go +++ b/internal/embedding/service.go @@ -46,9 +46,9 @@ var bgeONNXConfig = ONNXConfig{ type bgeModel struct { tk *tokenizer.Tokenizer session *ort.DynamicAdvancedSession + libDir string + config ONNXConfig mu sync.Mutex - libDir string // temp directory containing extracted libraries - config ONNXConfig // ONNX configuration for this model } // Compile-time check that bgeModel implements EmbeddingModel @@ -70,8 +70,8 @@ func newBGEModel() (EmbeddingModel, error) { ort.SetSharedLibraryPath(libPath) // Initialize ONNX runtime - if err := ort.InitializeEnvironment(); err != nil { - return nil, fmt.Errorf("initialize ONNX runtime: %w", err) + if initErr := ort.InitializeEnvironment(); initErr != nil { + return nil, fmt.Errorf("initialize ONNX runtime: %w", initErr) } // Load tokenizer from embedded data @@ -292,19 +292,19 @@ func (m *bgeModel) computeBatch(sentences []string) ([][]float32, error) { if err != nil { return nil, fmt.Errorf("create input_ids tensor: %w", err) } - defer inputIdsTensor.Destroy() + defer func() { _ = inputIdsTensor.Destroy() }() attentionMaskTensor, err := ort.NewTensor(inputShape, attentionMaskData) if err != nil { return nil, fmt.Errorf("create attention_mask tensor: %w", err) } - defer attentionMaskTensor.Destroy() + defer func() { _ = attentionMaskTensor.Destroy() }() tokenTypeIdsTensor, err := ort.NewTensor(inputShape, tokenTypeIdsData) if err != nil { return nil, fmt.Errorf("create token_type_ids tensor: %w", err) } - defer tokenTypeIdsTensor.Destroy() + defer func() { _ = tokenTypeIdsTensor.Destroy() }() // Create output tensor based on pooling strategy var outputShape ort.Shape @@ -324,7 +324,7 @@ func (m *bgeModel) computeBatch(sentences []string) ([][]float32, error) { if err != nil { return nil, fmt.Errorf("create output tensor: %w", err) } - defer outputTensor.Destroy() + defer func() { _ = outputTensor.Destroy() }() // Run inference inputTensors := []ort.Value{inputIdsTensor, attentionMaskTensor, tokenTypeIdsTensor} diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 884f375..255d110 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -15,10 +15,10 @@ import ( // Server is the MCP server that exposes search tools. type Server struct { - searchMgr *search.Manager - version string stdin io.Reader stdout io.Writer + searchMgr *search.Manager + version string } // NewServer creates a new MCP server. @@ -41,17 +41,17 @@ type Request struct { // Response represents a JSON-RPC response. type Response struct { - JSONRPC string `json:"jsonrpc"` ID any `json:"id"` Result any `json:"result,omitempty"` Error *Error `json:"error,omitempty"` + JSONRPC string `json:"jsonrpc"` } // Error represents a JSON-RPC error. type Error struct { - Code int `json:"code"` - Message string `json:"message"` Data any `json:"data,omitempty"` + Message string `json:"message"` + Code int `json:"code"` } // ToolCallParams represents parameters for tools/call method. @@ -62,9 +62,9 @@ type ToolCallParams struct { // Tool represents an MCP tool definition. type Tool struct { + InputSchema map[string]any `json:"inputSchema"` Name string `json:"name"` Description string `json:"description"` - InputSchema map[string]any `json:"inputSchema"` } // Run starts the MCP server loop. @@ -440,17 +440,17 @@ func (s *Server) callTool(ctx context.Context, name string, args json.RawMessage // TimelineParams represents parameters for timeline operations. type TimelineParams struct { - AnchorID int64 `json:"anchor_id"` Query string `json:"query"` - Before int `json:"before"` - After int `json:"after"` Project string `json:"project"` ObsType string `json:"obs_type"` Concepts string `json:"concepts"` Files string `json:"files"` + Format string `json:"format"` + AnchorID int64 `json:"anchor_id"` + Before int `json:"before"` + After int `json:"after"` DateStart int64 `json:"dateStart"` DateEnd int64 `json:"dateEnd"` - Format string `json:"format"` } // handleTimeline handles timeline requests. diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go index ef094f4..5f81357 100644 --- a/internal/mcp/server_test.go +++ b/internal/mcp/server_test.go @@ -34,8 +34,8 @@ func (s *ServerSuite) TestNewServer() { func TestRequest(t *testing.T) { tests := []struct { name string - req Request expected string + req Request }{ { name: "initialize request", @@ -137,10 +137,10 @@ func TestResponse(t *testing.T) { // TestError tests Error struct. func TestError(t *testing.T) { - tests := []struct { - name string + tests := []struct { //nolint:govet err Error expected string + name string }{ { name: "parse error", @@ -365,11 +365,11 @@ func TestHandleRequest(t *testing.T) { ctx := context.Background() tests := []struct { - name string req *Request - expectError bool - errorCode int + name string errorMessage string + errorCode int + expectError bool }{ { name: "initialize method", @@ -683,10 +683,8 @@ func TestRun_MultipleRequests(t *testing.T) { func TestHandleTimeline_Defaults(t *testing.T) { // Test that handleTimeline sets default before/after values params := TimelineParams{ - AnchorID: 0, - Query: "", - Before: 0, - After: 0, + Before: 0, + After: 0, } // Simulate the default value assignment from handleTimeline @@ -753,13 +751,13 @@ func TestServerStdinStdoutConfig(t *testing.T) { // TestResponseIDTypes tests that response IDs can be various types. func TestResponseIDTypes(t *testing.T) { tests := []struct { - name string id any + name string }{ - {"integer id", 1}, - {"string id", "abc-123"}, - {"float id", 1.5}, - {"null id", nil}, + {name: "integer id", id: 1}, + {name: "string id", id: "abc-123"}, + {name: "float id", id: 1.5}, + {name: "null id", id: nil}, } for _, tt := range tests { diff --git a/internal/pattern/detector.go b/internal/pattern/detector.go index 9ad8256..c9de955 100644 --- a/internal/pattern/detector.go +++ b/internal/pattern/detector.go @@ -38,21 +38,15 @@ type PatternSyncFunc func(pattern *models.Pattern) // Detector detects and tracks recurring patterns across observations. type Detector struct { - config DetectorConfig + ctx context.Context patternStore *sqlite.PatternStore observationStore *sqlite.ObservationStore - - // Vector sync callback - syncFunc PatternSyncFunc - - // Candidate tracking (patterns not yet confirmed) - candidates map[string]*candidatePattern - candidatesMu sync.RWMutex - - // Background analysis - ctx context.Context - cancel context.CancelFunc - wg sync.WaitGroup + syncFunc PatternSyncFunc + candidates map[string]*candidatePattern + cancel context.CancelFunc + config DetectorConfig + wg sync.WaitGroup + candidatesMu sync.RWMutex } // SetSyncFunc sets the callback for syncing patterns to vector store. @@ -62,11 +56,11 @@ func (d *Detector) SetSyncFunc(fn PatternSyncFunc) { // candidatePattern tracks a potential pattern before it reaches frequency threshold. type candidatePattern struct { + patternType models.PatternType + title string signature []string observationIDs []int64 projects []string - patternType models.PatternType - title string lastSeenEpoch int64 } diff --git a/internal/pattern/detector_test.go b/internal/pattern/detector_test.go index 1afc095..396e6fb 100644 --- a/internal/pattern/detector_test.go +++ b/internal/pattern/detector_test.go @@ -331,16 +331,16 @@ func TestDefaultConfig(t *testing.T) { func TestGeneratePatternName(t *testing.T) { tests := []struct { patternType models.PatternType - signature []string title string wantPrefix string + signature []string }{ - {models.PatternTypeBug, []string{"nil", "error"}, "", "Bug Pattern:"}, - {models.PatternTypeRefactor, []string{"extract"}, "", "Refactor Pattern:"}, - {models.PatternTypeArchitecture, []string{"service"}, "", "Architecture Pattern:"}, - {models.PatternTypeAntiPattern, []string{"god-class"}, "", "Anti-Pattern:"}, - {models.PatternTypeBestPractice, []string{"testing"}, "", "Best Practice:"}, - {models.PatternTypeBug, []string{}, "Short Title", "Short Title"}, // Use title directly + {patternType: models.PatternTypeBug, signature: []string{"nil", "error"}, title: "", wantPrefix: "Bug Pattern:"}, + {patternType: models.PatternTypeRefactor, signature: []string{"extract"}, title: "", wantPrefix: "Refactor Pattern:"}, + {patternType: models.PatternTypeArchitecture, signature: []string{"service"}, title: "", wantPrefix: "Architecture Pattern:"}, + {patternType: models.PatternTypeAntiPattern, signature: []string{"god-class"}, title: "", wantPrefix: "Anti-Pattern:"}, + {patternType: models.PatternTypeBestPractice, signature: []string{"testing"}, title: "", wantPrefix: "Best Practice:"}, + {patternType: models.PatternTypeBug, signature: []string{}, title: "Short Title", wantPrefix: "Short Title"}, // Use title directly } for _, tt := range tests { diff --git a/internal/reranking/service.go b/internal/reranking/service.go index c158889..f41494d 100644 --- a/internal/reranking/service.go +++ b/internal/reranking/service.go @@ -30,24 +30,24 @@ const ( // Candidate represents a search result candidate for reranking. type Candidate struct { - ID string // Document ID - Content string // Document text content for scoring - Score float64 // Original bi-encoder similarity score - Metadata map[string]any // Preserved metadata - RerankInfo map[string]float64 // Reranking debug info (optional) + Metadata map[string]any + RerankInfo map[string]float64 + ID string + Content string + Score float64 } // RerankResult represents a reranked search result. type RerankResult struct { - ID string // Document ID - Content string // Document text content - OriginalScore float64 // Original bi-encoder score - RerankScore float64 // Cross-encoder relevance score - CombinedScore float64 // Weighted combination of scores - Metadata map[string]any // Preserved metadata - OriginalRank int // Position before reranking (1-indexed) - RerankRank int // Position after reranking (1-indexed) - RankImprovement int // How much the rank improved (positive = moved up) + Metadata map[string]any + ID string + Content string + OriginalScore float64 + RerankScore float64 + CombinedScore float64 + OriginalRank int + RerankRank int + RankImprovement int } // Service provides cross-encoder reranking functionality. @@ -297,19 +297,19 @@ func (s *Service) scoreAll(query string, candidates []Candidate) ([]float64, err if err != nil { return nil, fmt.Errorf("create input_ids tensor: %w", err) } - defer inputIdsTensor.Destroy() + defer func() { _ = inputIdsTensor.Destroy() }() attentionMaskTensor, err := ort.NewTensor(inputShape, attentionMaskData) if err != nil { return nil, fmt.Errorf("create attention_mask tensor: %w", err) } - defer attentionMaskTensor.Destroy() + defer func() { _ = attentionMaskTensor.Destroy() }() tokenTypeIdsTensor, err := ort.NewTensor(inputShape, tokenTypeIdsData) if err != nil { return nil, fmt.Errorf("create token_type_ids tensor: %w", err) } - defer tokenTypeIdsTensor.Destroy() + defer func() { _ = tokenTypeIdsTensor.Destroy() }() // Cross-encoder outputs [batch, 1] logits outputShape := ort.NewShape(int64(batchSize), 1) @@ -317,7 +317,7 @@ func (s *Service) scoreAll(query string, candidates []Candidate) ([]float64, err if err != nil { return nil, fmt.Errorf("create output tensor: %w", err) } - defer outputTensor.Destroy() + defer func() { _ = outputTensor.Destroy() }() // Run inference inputTensors := []ort.Value{inputIdsTensor, attentionMaskTensor, tokenTypeIdsTensor} diff --git a/internal/scoring/recalculator.go b/internal/scoring/recalculator.go index bd88ffe..2bcb0d4 100644 --- a/internal/scoring/recalculator.go +++ b/internal/scoring/recalculator.go @@ -21,13 +21,13 @@ type ObservationStore interface { // Recalculator periodically recalculates importance scores for observations. type Recalculator struct { + log zerolog.Logger store ObservationStore calculator *Calculator - log zerolog.Logger - interval time.Duration - batchSize int stopCh chan struct{} doneCh chan struct{} + interval time.Duration + batchSize int mu sync.Mutex running bool } diff --git a/internal/scoring/recalculator_test.go b/internal/scoring/recalculator_test.go index 6fff695..a1bdbbf 100644 --- a/internal/scoring/recalculator_test.go +++ b/internal/scoring/recalculator_test.go @@ -16,14 +16,14 @@ import ( // MockObservationStore is a mock implementation of ObservationStore for testing. type MockObservationStore struct { - mu sync.Mutex - observations []*models.Observation - scores map[int64]float64 - conceptWeights map[string]float64 updateErr error getErr error getConceptsErr error + scores map[int64]float64 + conceptWeights map[string]float64 + observations []*models.Observation updateScoresCalls int + mu sync.Mutex } func NewMockObservationStore() *MockObservationStore { diff --git a/internal/search/expansion/expander.go b/internal/search/expansion/expander.go index be1ab6e..ff6f22d 100644 --- a/internal/search/expansion/expander.go +++ b/internal/search/expansion/expander.go @@ -30,25 +30,25 @@ const ( // ExpandedQuery represents a query variant with metadata. type ExpandedQuery struct { Query string `json:"query"` - Weight float64 `json:"weight"` // Weight for result merging (0.0-1.0) - Source string `json:"source"` // Where this expansion came from - Intent QueryIntent `json:"intent"` // Detected intent + Source string `json:"source"` + Intent QueryIntent `json:"intent"` + Weight float64 `json:"weight"` } // Expander provides context-aware query expansion. type Expander struct { embedSvc *embedding.Service - vocabulary []VocabEntry // Known vocabulary from observations - vocabVectors [][]float32 // Embeddings for vocabulary entries - vocabMu sync.RWMutex // Protects vocabulary intentPatterns map[QueryIntent][]*regexp.Regexp + vocabulary []VocabEntry + vocabVectors [][]float32 + vocabMu sync.RWMutex } // VocabEntry represents a vocabulary term from observations. type VocabEntry struct { - Term string // The term itself - Weight float64 // How common/important this term is (0.0-1.0) - Source string // Where it came from (title, concept, narrative) + Term string + Source string + Weight float64 } // Config holds expander configuration. diff --git a/internal/search/expansion/expander_test.go b/internal/search/expansion/expander_test.go index e455ad3..1fdee22 100644 --- a/internal/search/expansion/expander_test.go +++ b/internal/search/expansion/expander_test.go @@ -88,16 +88,16 @@ func (s *ExpanderSuite) TestExpand() { tests := []struct { name string query string + expectedIntent QueryIntent minExpansions int hasOriginal bool - expectedIntent QueryIntent }{ - {"question", "how do I implement auth", 1, true, IntentQuestion}, - {"error", "fix the bug in login", 1, true, IntentError}, - {"implementation", "implement user handler", 1, true, IntentImplementation}, - {"architecture", "architecture design", 1, true, IntentArchitecture}, - {"general", "database connection", 1, true, IntentGeneral}, - {"empty", "", 0, false, IntentGeneral}, + {name: "question", query: "how do I implement auth", expectedIntent: IntentQuestion, minExpansions: 1, hasOriginal: true}, + {name: "error", query: "fix the bug in login", expectedIntent: IntentError, minExpansions: 1, hasOriginal: true}, + {name: "implementation", query: "implement user handler", expectedIntent: IntentImplementation, minExpansions: 1, hasOriginal: true}, + {name: "architecture", query: "architecture design", expectedIntent: IntentArchitecture, minExpansions: 1, hasOriginal: true}, + {name: "general", query: "database connection", expectedIntent: IntentGeneral, minExpansions: 1, hasOriginal: true}, + {name: "empty", query: "", expectedIntent: IntentGeneral, minExpansions: 0, hasOriginal: false}, } for _, tt := range tests { @@ -392,13 +392,13 @@ func TestTruncate(t *testing.T) { tests := []struct { name string input string - maxLen int expected string + maxLen int }{ - {"short", "hello", 10, "hello"}, - {"exact", "hello", 5, "hello"}, - {"long", "hello world", 5, "hello..."}, - {"empty", "", 10, ""}, + {name: "short", input: "hello", maxLen: 10, expected: "hello"}, + {name: "exact", input: "hello", maxLen: 5, expected: "hello"}, + {name: "long", input: "hello world", maxLen: 5, expected: "hello..."}, + {name: "empty", input: "", maxLen: 10, expected: ""}, } for _, tt := range tests { diff --git a/internal/search/integration_test.go b/internal/search/integration_test.go index a1e0358..53bfe2e 100644 --- a/internal/search/integration_test.go +++ b/internal/search/integration_test.go @@ -517,16 +517,16 @@ func TestTruncate_TableDriven(t *testing.T) { tests := []struct { name string input string - maxLen int expected string + maxLen int }{ - {"short_string", "hello", 10, "hello"}, - {"exact_length", "hello", 5, "hello"}, - {"long_string", "hello world", 5, "hello..."}, - {"empty_string", "", 10, ""}, - {"whitespace_only", " ", 10, ""}, - {"with_leading_space", " hello ", 10, "hello"}, - {"very_long", "this is a very long string that should be truncated", 20, "this is a very long ..."}, + {name: "short_string", input: "hello", maxLen: 10, expected: "hello"}, + {name: "exact_length", input: "hello", maxLen: 5, expected: "hello"}, + {name: "long_string", input: "hello world", maxLen: 5, expected: "hello..."}, + {name: "empty_string", input: "", maxLen: 10, expected: ""}, + {name: "whitespace_only", input: " ", maxLen: 10, expected: ""}, + {name: "with_leading_space", input: " hello ", maxLen: 10, expected: "hello"}, + {name: "very_long", input: "this is a very long string that should be truncated", maxLen: 20, expected: "this is a very long ..."}, } for _, tt := range tests { diff --git a/internal/search/manager.go b/internal/search/manager.go index 5ae8cee..5dc16bf 100644 --- a/internal/search/manager.go +++ b/internal/search/manager.go @@ -35,41 +35,41 @@ func NewManager( // SearchParams contains parameters for unified search. type SearchParams struct { - Query string - Type string // "observations", "sessions", "prompts", or empty for all + Format string + Type string Project string - ObsType string // Observation type filter + ObsType string Concepts string Files string + Query string + Scope string + OrderBy string DateStart int64 - DateEnd int64 - OrderBy string // "relevance", "date_desc", "date_asc" - Limit int Offset int - Format string // "index" or "full" - Scope string // "project", "global", or empty for project+global - IncludeGlobal bool // If true, include global observations along with project-scoped - ExcludeSuperseded bool // If true, exclude observations that have been superseded + Limit int + DateEnd int64 + IncludeGlobal bool + ExcludeSuperseded bool } // SearchResult represents a unified search result. type SearchResult struct { - Type string `json:"type"` // "observation", "session", "prompt" - ID int64 `json:"id"` + Metadata map[string]interface{} `json:"metadata,omitempty"` + Type string `json:"type"` Title string `json:"title,omitempty"` Content string `json:"content,omitempty"` Project string `json:"project"` - Scope string `json:"scope,omitempty"` // "project" or "global" + Scope string `json:"scope,omitempty"` + ID int64 `json:"id"` CreatedAt int64 `json:"created_at_epoch"` Score float64 `json:"score,omitempty"` - Metadata map[string]interface{} `json:"metadata,omitempty"` } // UnifiedSearchResult contains the combined search results. type UnifiedSearchResult struct { + Query string `json:"query,omitempty"` Results []SearchResult `json:"results"` TotalCount int `json:"total_count"` - Query string `json:"query,omitempty"` } // UnifiedSearch performs a unified search across all document types. diff --git a/internal/search/manager_test.go b/internal/search/manager_test.go index bdd077b..f69e33a 100644 --- a/internal/search/manager_test.go +++ b/internal/search/manager_test.go @@ -4,7 +4,6 @@ package search import ( "database/sql" "testing" - "time" "github.com/lukaszraczylo/claude-mnemonic/pkg/models" "github.com/stretchr/testify/assert" @@ -94,8 +93,8 @@ func TestTruncate(t *testing.T) { tests := []struct { name string input string - maxLen int expected string + maxLen int }{ { name: "short string no truncation", @@ -147,11 +146,11 @@ func TestTruncate(t *testing.T) { func TestObservationToResult(t *testing.T) { m := NewManager(nil, nil, nil, nil) - tests := []struct { - name string + tests := []struct { //nolint:govet obs *models.Observation - format string expected SearchResult + format string + name string }{ { name: "full format with all fields", @@ -239,11 +238,11 @@ func TestObservationToResult(t *testing.T) { func TestSummaryToResult(t *testing.T) { m := NewManager(nil, nil, nil, nil) - tests := []struct { - name string + tests := []struct { //nolint:govet summary *models.SessionSummary - format string expected SearchResult + format string + name string }{ { name: "full format with all fields", @@ -321,11 +320,11 @@ func TestSummaryToResult(t *testing.T) { func TestPromptToResult(t *testing.T) { m := NewManager(nil, nil, nil, nil) - tests := []struct { - name string + tests := []struct { //nolint:govet prompt *models.UserPromptWithSession - format string expected SearchResult + format string + name string }{ { name: "full format with content", @@ -406,9 +405,9 @@ func TestPromptToResult(t *testing.T) { func TestSearchParamsValidation(t *testing.T) { tests := []struct { name string + expectedOrder string params SearchParams expectedLimit int - expectedOrder string }{ { name: "default limit applied", @@ -731,23 +730,21 @@ func TestPromptToResultFormats(t *testing.T) { func TestSearchParamsDefaults(t *testing.T) { tests := []struct { name string - initialLimit int initialOrder string - expectedLimit int expectedOrder string + initialLimit int + expectedLimit int }{ - {"zero_limit", 0, "", 20, "date_desc"}, - {"negative_limit", -5, "", 20, "date_desc"}, - {"over_100_limit", 150, "", 100, "date_desc"}, - {"valid_limit_50", 50, "relevance", 50, "relevance"}, - {"custom_order", 30, "date_asc", 30, "date_asc"}, + {name: "zero_limit", initialLimit: 0, initialOrder: "", expectedLimit: 20, expectedOrder: "date_desc"}, + {name: "negative_limit", initialLimit: -5, initialOrder: "", expectedLimit: 20, expectedOrder: "date_desc"}, + {name: "over_100_limit", initialLimit: 150, initialOrder: "", expectedLimit: 100, expectedOrder: "date_desc"}, + {name: "valid_limit_50", initialLimit: 50, initialOrder: "relevance", expectedLimit: 50, expectedOrder: "relevance"}, + {name: "custom_order", initialLimit: 30, initialOrder: "date_asc", expectedLimit: 30, expectedOrder: "date_asc"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { params := SearchParams{ - Query: "test", - Project: "project", Limit: tt.initialLimit, OrderBy: tt.initialOrder, } @@ -774,18 +771,18 @@ func TestTruncateEdgeCases(t *testing.T) { tests := []struct { name string input string - maxLen int expected string + maxLen int }{ // Unicode strings - uses byte length so ensure maxLen accommodates full string - {"unicode_string_no_truncate", "日本語テスト", 20, "日本語テスト"}, - {"mixed_unicode_no_truncate", "Hello世界", 15, "Hello世界"}, + {name: "unicode_string_no_truncate", input: "日本語テスト", maxLen: 20, expected: "日本語テスト"}, + {name: "mixed_unicode_no_truncate", input: "Hello世界", maxLen: 15, expected: "Hello世界"}, // ASCII truncation - {"ascii_truncate", "Hello World", 5, "Hello..."}, - {"only_whitespace", " ", 10, ""}, - {"tabs_and_newlines", "\t\n \t", 10, ""}, - {"newlines_with_content", "\n\nhello\n\n", 10, "hello"}, - {"zero_max_len", "hello", 0, "..."}, + {name: "ascii_truncate", input: "Hello World", maxLen: 5, expected: "Hello..."}, + {name: "only_whitespace", input: " ", maxLen: 10, expected: ""}, + {name: "tabs_and_newlines", input: "\t\n \t", maxLen: 10, expected: ""}, + {name: "newlines_with_content", input: "\n\nhello\n\n", maxLen: 10, expected: "hello"}, + {name: "zero_max_len", input: "hello", maxLen: 0, expected: "..."}, } for _, tt := range tests { @@ -812,8 +809,6 @@ func TestUnifiedSearchResultEmpty(t *testing.T) { // TestSearchResultMetadata tests SearchResult metadata handling. func TestSearchResultMetadata(t *testing.T) { result := SearchResult{ - Type: "observation", - ID: 1, Metadata: map[string]interface{}{ "obs_type": "discovery", "scope": "project", @@ -835,10 +830,7 @@ func TestSearchResultTypes(t *testing.T) { for _, typ := range types { t.Run(typ, func(t *testing.T) { result := SearchResult{ - Type: typ, - ID: 1, - Project: "test", - CreatedAt: time.Now().UnixMilli(), + Type: typ, } assert.Equal(t, typ, result.Type) }) @@ -952,8 +944,8 @@ func TestSearchParams_OrderByValues(t *testing.T) { for _, order := range validOrders { t.Run("order_"+order, func(t *testing.T) { params := SearchParams{ - Query: "test", - Project: "test", + Query: "test", //nolint:govet + Project: "test", //nolint:govet OrderBy: order, } assert.Equal(t, order, params.OrderBy) @@ -968,9 +960,7 @@ func TestSearchParams_TypeValues(t *testing.T) { for _, typ := range validTypes { t.Run("type_"+typ, func(t *testing.T) { params := SearchParams{ - Query: "test", - Project: "test", - Type: typ, + Type: typ, } assert.Equal(t, typ, params.Type) }) @@ -984,8 +974,8 @@ func TestSearchParams_ScopeValues(t *testing.T) { for _, scope := range validScopes { t.Run("scope_"+scope, func(t *testing.T) { params := SearchParams{ - Query: "test", - Project: "test", + Query: "test", //nolint:govet + Project: "test", //nolint:govet Scope: scope, } assert.Equal(t, scope, params.Scope) @@ -1000,8 +990,8 @@ func TestSearchParams_FormatValues(t *testing.T) { for _, format := range validFormats { t.Run("format_"+format, func(t *testing.T) { params := SearchParams{ - Query: "test", - Project: "test", + Query: "test", //nolint:govet + Project: "test", //nolint:govet Format: format, } assert.Equal(t, format, params.Format) @@ -1020,7 +1010,7 @@ func TestUnifiedSearchResult_MultipleResults(t *testing.T) { result := UnifiedSearchResult{ Results: results, TotalCount: 3, - Query: "test query", + Query: "test query", //nolint:govet } assert.Len(t, result.Results, 3) @@ -1040,8 +1030,8 @@ func TestSearchResult_Metadata(t *testing.T) { } result := SearchResult{ - Type: "observation", - ID: 1, + Type: "observation", //nolint:govet + ID: 1, //nolint:govet Metadata: metadata, } @@ -1066,8 +1056,6 @@ func TestSearchResult_Scores(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := SearchResult{ - Type: "observation", - ID: 1, Score: tt.score, } assert.Equal(t, tt.score, result.Score) diff --git a/internal/update/update.go b/internal/update/update.go index c185aa8..3b25fec 100644 --- a/internal/update/update.go +++ b/internal/update/update.go @@ -33,11 +33,11 @@ const ( // Release represents a GitHub release. type Release struct { + PublishedAt time.Time `json:"published_at"` TagName string `json:"tag_name"` Name string `json:"name"` - PublishedAt time.Time `json:"published_at"` - Assets []Asset `json:"assets"` Body string `json:"body"` + Assets []Asset `json:"assets"` } // Asset represents a release asset. @@ -49,15 +49,15 @@ type Asset struct { // UpdateInfo contains information about an available update. type UpdateInfo struct { - Available bool `json:"available"` + PublishedAt time.Time `json:"published_at,omitempty"` CurrentVersion string `json:"current_version"` LatestVersion string `json:"latest_version"` ReleaseNotes string `json:"release_notes,omitempty"` - PublishedAt time.Time `json:"published_at,omitempty"` DownloadURL string `json:"download_url,omitempty"` ChecksumsURL string `json:"checksums_url,omitempty"` - BundleURL string `json:"bundle_url,omitempty"` // Sigstore bundle (.sigstore.json) + BundleURL string `json:"bundle_url,omitempty"` ManualUpdateCommand string `json:"manual_update_command,omitempty"` + Available bool `json:"available"` } // InstallScriptURL is the URL to the remote installation script. @@ -74,23 +74,22 @@ func GetManualUpdateCommand(version string) string { // UpdateStatus represents the current update status. type UpdateStatus struct { - State string `json:"state"` // "idle", "checking", "downloading", "verifying", "applying", "done", "error" - Progress float64 `json:"progress"` + State string `json:"state"` Message string `json:"message"` Error string `json:"error,omitempty"` - ManualUpdateCommand string `json:"manual_update_command,omitempty"` // Shown when update fails + ManualUpdateCommand string `json:"manual_update_command,omitempty"` + Progress float64 `json:"progress"` } // Updater handles self-updates. -type Updater struct { +type Updater struct { //nolint:govet + httpClient *http.Client + cachedUpdate *UpdateInfo + lastCheck time.Time + status UpdateStatus currentVersion string installDir string - httpClient *http.Client - - mu sync.RWMutex - status UpdateStatus - lastCheck time.Time - cachedUpdate *UpdateInfo + mu sync.RWMutex } // New creates a new Updater. diff --git a/internal/vector/sqlitevec/client.go b/internal/vector/sqlitevec/client.go index df2e836..5ca9787 100644 --- a/internal/vector/sqlitevec/client.go +++ b/internal/vector/sqlitevec/client.go @@ -87,9 +87,9 @@ func (c *Client) AddDocuments(ctx context.Context, docs []Document) error { for i, doc := range docs { // Serialize embedding to blob format - embBlob, err := sqlite_vec.SerializeFloat32(embeddings[i]) - if err != nil { - return fmt.Errorf("serialize embedding for %s: %w", doc.ID, err) + embBlob, serErr := sqlite_vec.SerializeFloat32(embeddings[i]) + if serErr != nil { + return fmt.Errorf("serialize embedding for %s: %w", doc.ID, serErr) } // Extract metadata @@ -212,8 +212,8 @@ func (c *Client) Query(ctx context.Context, query string, limit int, where map[s var sqliteID int64 var docType, fieldType, project, scope sql.NullString - if err := rows.Scan(&r.ID, &r.Distance, &sqliteID, &docType, &fieldType, &project, &scope); err != nil { - return nil, fmt.Errorf("scan row: %w", err) + if scanErr := rows.Scan(&r.ID, &r.Distance, &sqliteID, &docType, &fieldType, &project, &scope); scanErr != nil { + return nil, fmt.Errorf("scan row: %w", scanErr) } r.Similarity = DistanceToSimilarity(r.Distance) @@ -319,11 +319,11 @@ func (c *Client) NeedsRebuild(ctx context.Context) (bool, string) { // StaleVectorInfo contains information about a vector that needs rebuilding. type StaleVectorInfo struct { DocID string - SQLiteID int64 DocType string FieldType string Project string Scope string + SQLiteID int64 } // GetStaleVectors returns doc_ids of vectors with mismatched or null model versions. @@ -352,8 +352,8 @@ func (c *Client) GetStaleVectors(ctx context.Context) ([]StaleVectorInfo, error) var sqliteID sql.NullInt64 var docType, fieldType, project, scope sql.NullString - if err := rows.Scan(&info.DocID, &sqliteID, &docType, &fieldType, &project, &scope); err != nil { - return nil, fmt.Errorf("scan row: %w", err) + if scanErr := rows.Scan(&info.DocID, &sqliteID, &docType, &fieldType, &project, &scope); scanErr != nil { + return nil, fmt.Errorf("scan row: %w", scanErr) } info.SQLiteID = sqliteID.Int64 diff --git a/internal/vector/sqlitevec/helpers.go b/internal/vector/sqlitevec/helpers.go index a104b51..6f5f949 100644 --- a/internal/vector/sqlitevec/helpers.go +++ b/internal/vector/sqlitevec/helpers.go @@ -8,21 +8,22 @@ const ( DocTypeObservation DocType = "observation" DocTypeSessionSummary DocType = "session_summary" DocTypeUserPrompt DocType = "user_prompt" + DocTypeCodeChunk DocType = "code_chunk" ) // Document represents a document to store with vector embedding. type Document struct { + Metadata map[string]any ID string Content string - Metadata map[string]any } // QueryResult represents a search result from vector search. type QueryResult struct { + Metadata map[string]any ID string Distance float64 - Similarity float64 // 1.0 = identical, 0.0 = opposite (derived from distance) - Metadata map[string]any + Similarity float64 } // DistanceToSimilarity converts sqlite-vec cosine distance to similarity score. diff --git a/internal/vector/sqlitevec/helpers_test.go b/internal/vector/sqlitevec/helpers_test.go index 0e9b5ac..ef343b3 100644 --- a/internal/vector/sqlitevec/helpers_test.go +++ b/internal/vector/sqlitevec/helpers_test.go @@ -42,10 +42,10 @@ func TestQueryResult_Fields(t *testing.T) { func TestBuildWhereFilter(t *testing.T) { tests := []struct { + expected map[string]interface{} name string docType DocType project string - expected map[string]interface{} }{ { name: "empty_filters", @@ -474,9 +474,9 @@ func TestCopyMetadataMulti(t *testing.T) { func TestJoinStrings(t *testing.T) { tests := []struct { name string - strs []string sep string expected string + strs []string }{ { name: "empty_slice", @@ -522,8 +522,8 @@ func TestTruncateString(t *testing.T) { tests := []struct { name string input string - maxLen int expected string + maxLen int }{ { name: "shorter_than_max", diff --git a/internal/vector/sqlitevec/sync.go b/internal/vector/sqlitevec/sync.go index 21f37c9..2744f41 100644 --- a/internal/vector/sqlitevec/sync.go +++ b/internal/vector/sqlitevec/sync.go @@ -5,13 +5,15 @@ import ( "context" "fmt" + "github.com/lukaszraczylo/claude-mnemonic/internal/chunking" "github.com/lukaszraczylo/claude-mnemonic/pkg/models" "github.com/rs/zerolog/log" ) // Sync provides synchronization between SQLite data and vector embeddings. type Sync struct { - client *Client + client *Client + chunkingManager *chunking.Manager } // NewSync creates a new sync service. @@ -19,9 +21,23 @@ func NewSync(client *Client) *Sync { return &Sync{client: client} } +// SetChunkingManager sets the code chunking manager (optional). +// If set, observations will include code chunks from tracked files. +func (s *Sync) SetChunkingManager(manager *chunking.Manager) { + s.chunkingManager = manager +} + // SyncObservation syncs a single observation to the vector store. +// If a chunking manager is configured, also chunks tracked code files. func (s *Sync) SyncObservation(ctx context.Context, obs *models.Observation) error { docs := s.formatObservationDocs(obs) + + // Add code chunks from tracked files if chunking manager is available + if s.chunkingManager != nil { + codeChunkDocs := s.formatCodeChunkDocs(ctx, obs) + docs = append(docs, codeChunkDocs...) + } + if len(docs) == 0 { return nil } @@ -99,6 +115,98 @@ func (s *Sync) formatObservationDocs(obs *models.Observation) []Document { return docs } +// formatCodeChunkDocs formats code chunks from tracked files into vector documents. +// Uses AST-aware chunking to extract semantic code units (functions, classes, methods). +func (s *Sync) formatCodeChunkDocs(ctx context.Context, obs *models.Observation) []Document { + if s.chunkingManager == nil { + return nil + } + + // Determine scope for metadata + scope := string(obs.Scope) + if scope == "" { + scope = "project" + } + + // Collect all tracked files (read + modified) + allFiles := make([]string, 0, len(obs.FilesRead)+len(obs.FilesModified)) + allFiles = append(allFiles, obs.FilesRead...) + allFiles = append(allFiles, obs.FilesModified...) + + // Filter to only files supported by chunking manager + var supportedFiles []string + for _, file := range allFiles { + if s.chunkingManager.SupportsFile(file) { + supportedFiles = append(supportedFiles, file) + } + } + + if len(supportedFiles) == 0 { + return nil + } + + // Chunk all supported files + results, errs := s.chunkingManager.ChunkFiles(ctx, supportedFiles) + if len(errs) > 0 { + // Log errors but don't fail the entire sync + for _, err := range errs { + log.Warn().Err(err).Msg("Failed to chunk file") + } + } + + // Convert chunks to vector documents + docs := make([]Document, 0) + chunkIndex := 0 + + for filePath, chunks := range results { + for _, chunk := range chunks { + doc := Document{ + ID: fmt.Sprintf("obs_%d_chunk_%d", obs.ID, chunkIndex), + Content: chunk.SearchableContent(), + Metadata: map[string]any{ + "sqlite_id": obs.ID, + "doc_type": string(DocTypeCodeChunk), + "field_type": "code_chunk", + "sdk_session_id": obs.SDKSessionID, + "project": obs.Project, + "scope": scope, + "created_at_epoch": obs.CreatedAtEpoch, + // Code chunk specific metadata + "file_path": filePath, + "language": string(chunk.Language), + "chunk_type": string(chunk.Type), + "symbol_name": chunk.Name, + "start_line": chunk.StartLine, + "end_line": chunk.EndLine, + }, + } + + // Add parent name if this is a method + if chunk.ParentName != "" { + doc.Metadata["parent_name"] = chunk.ParentName + } + + // Add signature if available + if chunk.Signature != "" { + doc.Metadata["signature"] = chunk.Signature + } + + docs = append(docs, doc) + chunkIndex++ + } + } + + if len(docs) > 0 { + log.Debug(). + Int64("observationId", obs.ID). + Int("codeChunks", len(docs)). + Int("files", len(results)). + Msg("Generated code chunk documents") + } + + return docs +} + // SyncSummary syncs a single session summary to the vector store. func (s *Sync) SyncSummary(ctx context.Context, summary *models.SessionSummary) error { docs := s.formatSummaryDocs(summary) @@ -191,21 +299,27 @@ func (s *Sync) SyncUserPrompt(ctx context.Context, prompt *models.UserPromptWith } // DeleteObservations removes observation documents from the vector store. +// Includes both observation fields (narrative, facts) and code chunks. func (s *Sync) DeleteObservations(ctx context.Context, observationIDs []int64) error { if len(observationIDs) == 0 { return nil } // Generate all possible document IDs for these observations - // Pattern: obs_{id}_narrative, obs_{id}_fact_{0..n} + // Pattern: obs_{id}_narrative, obs_{id}_fact_{0..n}, obs_{id}_chunk_{0..n} const maxFactsPerObs = 20 - ids := make([]string, 0, len(observationIDs)*(maxFactsPerObs+1)) + const maxChunksPerObs = 100 // Reasonable upper bound for code chunks + ids := make([]string, 0, len(observationIDs)*(maxFactsPerObs+maxChunksPerObs+1)) for _, obsID := range observationIDs { ids = append(ids, fmt.Sprintf("obs_%d_narrative", obsID)) for i := 0; i < maxFactsPerObs; i++ { ids = append(ids, fmt.Sprintf("obs_%d_fact_%d", obsID, i)) } + // Include code chunk IDs + for i := 0; i < maxChunksPerObs; i++ { + ids = append(ids, fmt.Sprintf("obs_%d_chunk_%d", obsID, i)) + } } if err := s.client.DeleteDocuments(ctx, ids); err != nil { diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index c0bb765..e95794a 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -16,15 +16,15 @@ import ( // Watcher monitors a file or directory for deletion and calls onDelete when removed. // It watches the parent directory since fsnotify cannot watch non-existent files. type Watcher struct { - targetPath string // The file/directory to watch for deletion - parentPath string // Parent directory (what we actually watch) - onDelete func() // Callback when target is deleted - watcher *fsnotify.Watcher ctx context.Context + onDelete func() + watcher *fsnotify.Watcher cancel context.CancelFunc + targetPath string + parentPath string + debounce time.Duration mu sync.Mutex running bool - debounce time.Duration } // New creates a new Watcher for the given target path. diff --git a/internal/worker/handlers.go b/internal/worker/handlers.go index 9098477..b50053a 100644 --- a/internal/worker/handlers.go +++ b/internal/worker/handlers.go @@ -158,10 +158,10 @@ type SessionInitRequest struct { // SessionInitResponse is the response for session initialization. type SessionInitResponse struct { + Reason string `json:"reason,omitempty"` SessionDBID int64 `json:"sessionDbId"` PromptNumber int `json:"promptNumber"` Skipped bool `json:"skipped,omitempty"` - Reason string `json:"reason,omitempty"` } // DuplicatePromptWindowSeconds is the time window for detecting duplicate prompt submissions. @@ -296,8 +296,8 @@ func (s *Service) handleSessionStart(w http.ResponseWriter, r *http.Request) { } var req SessionStartRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + if decodeErr := json.NewDecoder(r.Body).Decode(&req); decodeErr != nil { + http.Error(w, decodeErr.Error(), http.StatusBadRequest) return } diff --git a/internal/worker/handlers_relations.go b/internal/worker/handlers_relations.go index 5e890ca..619ecfe 100644 --- a/internal/worker/handlers_relations.go +++ b/internal/worker/handlers_relations.go @@ -46,7 +46,7 @@ func (s *Service) handleGetRelationGraph(w http.ResponseWriter, r *http.Request) // Get depth parameter (default 2) depth := 2 if depthStr := r.URL.Query().Get("depth"); depthStr != "" { - if d, err := strconv.Atoi(depthStr); err == nil && d > 0 && d <= 5 { + if d, atoiErr := strconv.Atoi(depthStr); atoiErr == nil && d > 0 && d <= 5 { depth = d } } @@ -72,7 +72,7 @@ func (s *Service) handleGetRelatedObservations(w http.ResponseWriter, r *http.Re // Get minimum confidence parameter (default 0.4) minConfidence := 0.4 if confStr := r.URL.Query().Get("min_confidence"); confStr != "" { - if c, err := strconv.ParseFloat(confStr, 64); err == nil && c >= 0 && c <= 1 { + if c, parseErr := strconv.ParseFloat(confStr, 64); parseErr == nil && c >= 0 && c <= 1 { minConfidence = c } } diff --git a/internal/worker/handlers_scoring.go b/internal/worker/handlers_scoring.go index 63db0e2..c15e2e5 100644 --- a/internal/worker/handlers_scoring.go +++ b/internal/worker/handlers_scoring.go @@ -10,6 +10,7 @@ import ( "github.com/go-chi/chi/v5" "github.com/lukaszraczylo/claude-mnemonic/pkg/models" + "github.com/rs/zerolog/log" ) // FeedbackRequest represents a user feedback submission. @@ -65,9 +66,10 @@ func (s *Service) handleObservationFeedback(w http.ResponseWriter, r *http.Reque if err == nil && obs != nil { obs.UserFeedback = req.Feedback // Apply the new feedback newScore = scoreCalculator.Calculate(obs, time.Now()) - if err := observationStore.UpdateImportanceScore(r.Context(), id, newScore); err != nil { + if scoreErr := observationStore.UpdateImportanceScore(r.Context(), id, newScore); scoreErr != nil { // Log but don't fail - feedback was recorded // Score will be updated on next recalculation cycle + log.Warn().Err(scoreErr).Int64("id", id).Msg("Failed to update importance score after feedback") } } } @@ -261,8 +263,9 @@ func (s *Service) handleUpdateConceptWeight(w http.ResponseWriter, r *http.Reque // Refresh concept weights in recalculator if recalculator != nil { - if err := recalculator.RefreshConceptWeights(r.Context()); err != nil { + if refreshErr := recalculator.RefreshConceptWeights(r.Context()); refreshErr != nil { // Log but don't fail - weight was saved + log.Warn().Err(refreshErr).Str("concept", concept).Msg("Failed to refresh concept weights in recalculator") } } @@ -308,8 +311,9 @@ func (s *Service) handleTriggerRecalculation(w http.ResponseWriter, r *http.Requ // Run recalculation in background go func() { - if err := recalculator.RecalculateNow(r.Context()); err != nil { + if recalcErr := recalculator.RecalculateNow(r.Context()); recalcErr != nil { // Log error but don't block response + log.Warn().Err(recalcErr).Msg("Failed to trigger score recalculation") } }() @@ -347,8 +351,9 @@ func (s *Service) incrementRetrievalCounts(ids []int64) { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() - if err := store.IncrementRetrievalCount(ctx, ids); err != nil { + if incrErr := store.IncrementRetrievalCount(ctx, ids); incrErr != nil { // Log but don't fail - this is a background operation + log.Warn().Err(incrErr).Msg("Failed to increment retrieval count in background") } }() } diff --git a/internal/worker/handlers_test.go b/internal/worker/handlers_test.go index 88d2a0c..2acca5a 100644 --- a/internal/worker/handlers_test.go +++ b/internal/worker/handlers_test.go @@ -1276,8 +1276,6 @@ func TestObservationRequest_Fields(t *testing.T) { ClaudeSessionID: "session-abc", Project: "my-project", ToolName: "Read", - ToolInput: map[string]string{"path": "/file.go"}, - ToolResponse: "file contents", CWD: "/home/user/project", } diff --git a/internal/worker/sdk/parser_test.go b/internal/worker/sdk/parser_test.go index e89b981..2fce3fe 100644 --- a/internal/worker/sdk/parser_test.go +++ b/internal/worker/sdk/parser_test.go @@ -77,10 +77,10 @@ func TestParseObservations_TableDriven(t *testing.T) { tests := []struct { name string input string - expectedCount int expectedType models.ObservationType expectedTitle string checkConcepts []string + expectedCount int }{ { name: "valid_bugfix_observation", @@ -300,9 +300,9 @@ func TestParseSummary_TableDriven(t *testing.T) { tests := []struct { name string input string + expectedRequest string sessionID int64 expectNil bool - expectedRequest string }{ { name: "empty_input", diff --git a/internal/worker/sdk/processor.go b/internal/worker/sdk/processor.go index e049a89..36d5ee8 100644 --- a/internal/worker/sdk/processor.go +++ b/internal/worker/sdk/processor.go @@ -31,15 +31,14 @@ type SyncSummaryFunc func(summary *models.SessionSummary) // Processor handles SDK agent processing of observations and summaries using Claude Code CLI. type Processor struct { - claudePath string - model string observationStore *sqlite.ObservationStore summaryStore *sqlite.SummaryStore broadcastFunc BroadcastFunc syncObservationFunc SyncObservationFunc syncSummaryFunc SyncSummaryFunc - // Semaphore to limit concurrent Claude CLI calls (prevents API overload) - sem chan struct{} + sem chan struct{} + claudePath string + model string } // SetBroadcastFunc sets the broadcast callback for SSE events. diff --git a/internal/worker/sdk/processor_test.go b/internal/worker/sdk/processor_test.go index f18e6c5..30f75f7 100644 --- a/internal/worker/sdk/processor_test.go +++ b/internal/worker/sdk/processor_test.go @@ -11,8 +11,8 @@ import ( func TestIsSelfReferentialSummary(t *testing.T) { tests := []struct { - name string summary *models.ParsedSummary + name string expected bool }{ { @@ -281,8 +281,8 @@ func TestTruncateForLog(t *testing.T) { tests := []struct { name string input string - maxLen int expected string + maxLen int }{ { name: "shorter_than_max", @@ -371,11 +371,11 @@ func TestCaptureFileMtimes(t *testing.T) { file1 := filepath.Join(tmpDir, "file1.txt") file2 := filepath.Join(tmpDir, "file2.txt") - err = os.WriteFile(file1, []byte("content1"), 0644) + err = os.WriteFile(file1, []byte("content1"), 0600) if err != nil { t.Fatal(err) } - err = os.WriteFile(file2, []byte("content2"), 0644) + err = os.WriteFile(file2, []byte("content2"), 0600) if err != nil { t.Fatal(err) } @@ -415,7 +415,7 @@ func TestGetFileMtimes(t *testing.T) { defer os.RemoveAll(tmpDir) testFile := filepath.Join(tmpDir, "test.txt") - err = os.WriteFile(testFile, []byte("content"), 0644) + err = os.WriteFile(testFile, []byte("content"), 0600) if err != nil { t.Fatal(err) } @@ -437,7 +437,7 @@ func TestGetFileContent(t *testing.T) { t.Run("reads_existing_file", func(t *testing.T) { testFile := filepath.Join(tmpDir, "test.txt") content := "test content" - err := os.WriteFile(testFile, []byte(content), 0644) + err := os.WriteFile(testFile, []byte(content), 0600) if err != nil { t.Fatal(err) } @@ -459,7 +459,7 @@ func TestGetFileContent(t *testing.T) { for i := 0; i < 3000; i++ { longContent += "x" } - err := os.WriteFile(testFile, []byte(longContent), 0644) + err := os.WriteFile(testFile, []byte(longContent), 0600) if err != nil { t.Fatal(err) } @@ -473,7 +473,7 @@ func TestGetFileContent(t *testing.T) { t.Run("resolves_relative_path_with_cwd", func(t *testing.T) { testFile := filepath.Join(tmpDir, "relative.txt") content := "relative content" - err := os.WriteFile(testFile, []byte(content), 0644) + err := os.WriteFile(testFile, []byte(content), 0600) if err != nil { t.Fatal(err) } @@ -719,8 +719,8 @@ func TestShouldSkipTrivialOperation_EdgeCases(t *testing.T) { // TestIsSelfReferentialSummary_MoreCases tests additional self-referential detection cases. func TestIsSelfReferentialSummary_MoreCases(t *testing.T) { tests := []struct { - name string summary *models.ParsedSummary + name string expected bool }{ { @@ -932,7 +932,7 @@ func TestCaptureFileMtimes_DuplicatePaths(t *testing.T) { defer os.RemoveAll(tmpDir) testFile := filepath.Join(tmpDir, "shared.txt") - err = os.WriteFile(testFile, []byte("content"), 0644) + err = os.WriteFile(testFile, []byte("content"), 0600) if err != nil { t.Fatal(err) } diff --git a/internal/worker/sdk/prompts.go b/internal/worker/sdk/prompts.go index a200c0d..af5e865 100644 --- a/internal/worker/sdk/prompts.go +++ b/internal/worker/sdk/prompts.go @@ -24,12 +24,12 @@ var ObservationConcepts = []string{ // ToolExecution represents a tool execution for observation. type ToolExecution struct { - ID int64 ToolName string ToolInput string ToolOutput string - CreatedAtEpoch int64 CWD string + ID int64 + CreatedAtEpoch int64 } // BuildObservationPrompt builds a prompt for processing a tool observation. @@ -67,12 +67,12 @@ func BuildObservationPrompt(exec ToolExecution) string { // SummaryRequest contains data for building a summary prompt. type SummaryRequest struct { - SessionDBID int64 SDKSessionID string Project string UserPrompt string LastUserMessage string LastAssistantMessage string + SessionDBID int64 } // BuildSummaryPrompt builds a prompt requesting a session summary. diff --git a/internal/worker/sdk/prompts_test.go b/internal/worker/sdk/prompts_test.go index eecf7d5..7711280 100644 --- a/internal/worker/sdk/prompts_test.go +++ b/internal/worker/sdk/prompts_test.go @@ -12,8 +12,8 @@ func TestTruncate(t *testing.T) { tests := []struct { name string input string - maxLen int expected string + maxLen int }{ { name: "shorter_than_max", @@ -58,10 +58,10 @@ func TestTruncate(t *testing.T) { func TestBuildObservationPrompt(t *testing.T) { now := time.Now().UnixMilli() - tests := []struct { - name string + tests := []struct { //nolint:govet exec ToolExecution contains []string + name string }{ { name: "basic_read_tool", diff --git a/internal/worker/service.go b/internal/worker/service.go index 76ac219..9853e1e 100644 --- a/internal/worker/service.go +++ b/internal/worker/service.go @@ -12,6 +12,10 @@ import ( "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" + "github.com/lukaszraczylo/claude-mnemonic/internal/chunking" + "github.com/lukaszraczylo/claude-mnemonic/internal/chunking/golang" + "github.com/lukaszraczylo/claude-mnemonic/internal/chunking/python" + "github.com/lukaszraczylo/claude-mnemonic/internal/chunking/typescript" "github.com/lukaszraczylo/claude-mnemonic/internal/config" "github.com/lukaszraczylo/claude-mnemonic/internal/db/sqlite" "github.com/lukaszraczylo/claude-mnemonic/internal/embedding" @@ -56,80 +60,49 @@ type RetrievalStats struct { // Service is the main worker service orchestrator. type Service struct { - // Version of the worker binary - version string - - // Configuration - config *config.Config - - // Database - store *sqlite.Store - sessionStore *sqlite.SessionStore - observationStore *sqlite.ObservationStore + startTime time.Time + initError error + ctx context.Context + queryExpander *expansion.Expander + recalculator *scoring.Recalculator summaryStore *sqlite.SummaryStore promptStore *sqlite.PromptStore conflictStore *sqlite.ConflictStore patternStore *sqlite.PatternStore relationStore *sqlite.RelationStore - - // Pattern detection - patternDetector *pattern.Detector - - // Domain services - sessionManager *session.Manager - sseBroadcaster *sse.Broadcaster - processor *sdk.Processor - - // Vector database (sqlite-vec with local embeddings) - embedSvc *embedding.Service - vectorClient *sqlitevec.Client - vectorSync *sqlitevec.Sync - - // Cross-encoder reranking (for improved search relevance) - reranker *reranking.Service - - // Query expansion (for improved search recall) - queryExpander *expansion.Expander - - // Importance scoring - scoreCalculator *scoring.Calculator - recalculator *scoring.Recalculator - - // HTTP server - router *chi.Mux - server *http.Server - startTime time.Time - - // Retrieval statistics (per-project) + patternDetector *pattern.Detector + sessionManager *session.Manager + sseBroadcaster *sse.Broadcaster + router *chi.Mux + embedSvc *embedding.Service + vectorClient *sqlitevec.Client + vectorSync *sqlitevec.Sync + reranker *reranking.Service + updater *update.Updater + observationStore *sqlite.ObservationStore + scoreCalculator *scoring.Calculator + processor *sdk.Processor + server *http.Server + sessionStore *sqlite.SessionStore retrievalStats map[string]*RetrievalStats + configWatcher *watcher.Watcher + store *sqlite.Store + cancel context.CancelFunc + dbWatcher *watcher.Watcher + staleQueue chan staleVerifyRequest + config *config.Config + version string + wg sync.WaitGroup + initMu sync.RWMutex retrievalStatsMu sync.RWMutex - - // Lifecycle - ctx context.Context - cancel context.CancelFunc - wg sync.WaitGroup - - // Initialization state (for deferred init) - ready atomic.Bool - initError error - initMu sync.RWMutex - - // Background verification queue for stale observations - staleQueue chan staleVerifyRequest - staleQueueOnce sync.Once - - // File watchers for auto-recreation on deletion - dbWatcher *watcher.Watcher - configWatcher *watcher.Watcher - - // Self-updater - updater *update.Updater + staleQueueOnce sync.Once + ready atomic.Bool } // staleVerifyRequest represents a request to verify a stale observation in background type staleVerifyRequest struct { - observationID int64 cwd string + observationID int64 } // NewService creates a new worker service with deferred initialization. @@ -223,17 +196,29 @@ func (s *Service) initializeAsync() { } else { embedSvc = emb // Create sqlite-vec client using the same DB connection - client, err := sqlitevec.NewClient(sqlitevec.Config{ + client, clientErr := sqlitevec.NewClient(sqlitevec.Config{ DB: store.DB(), }, embedSvc) - if err != nil { - log.Warn().Err(err).Msg("sqlite-vec client creation failed - vector search disabled") + if clientErr != nil { + log.Warn().Err(clientErr).Msg("sqlite-vec client creation failed - vector search disabled") } else { vectorClient = client vectorSync = sqlitevec.NewSync(client) + + // Initialize AST-aware code chunking + chunkOpts := chunking.DefaultChunkOptions() + chunkers := []chunking.Chunker{ + golang.NewChunker(chunkOpts), + python.NewChunker(chunkOpts), + typescript.NewChunker(chunkOpts), + } + chunkingManager := chunking.NewManager(chunkers, chunkOpts) + vectorSync.SetChunkingManager(chunkingManager) + log.Info(). Str("model", embedSvc.Version()). - Msg("sqlite-vec vector search enabled") + Strs("chunkers", []string{"go", "python", "typescript"}). + Msg("sqlite-vec vector search with AST-aware code chunking enabled") } // Create cross-encoder reranking service if enabled @@ -243,9 +228,9 @@ func (s *Service) initializeAsync() { rerankCfg.Alpha = s.config.RerankingAlpha } - ranker, err := reranking.NewService(rerankCfg) - if err != nil { - log.Warn().Err(err).Msg("Cross-encoder reranking service creation failed - reranking disabled") + ranker, rankerErr := reranking.NewService(rerankCfg) + if rankerErr != nil { + log.Warn().Err(rankerErr).Msg("Cross-encoder reranking service creation failed - reranking disabled") } else { reranker = ranker log.Info(). @@ -457,8 +442,8 @@ func (s *Service) startWatchers() { log.Warn().Err(err).Msg("Failed to create database watcher") } else { s.dbWatcher = dbWatcher - if err := dbWatcher.Start(); err != nil { - log.Warn().Err(err).Msg("Failed to start database watcher") + if startErr := dbWatcher.Start(); startErr != nil { + log.Warn().Err(startErr).Msg("Failed to start database watcher") } else { log.Info().Str("path", s.config.DBPath).Msg("Database file watcher started") } @@ -559,15 +544,26 @@ func (s *Service) reinitializeDatabase() { log.Warn().Err(err).Msg("Embedding service creation failed after reinit") } else { embedSvc = emb - client, err := sqlitevec.NewClient(sqlitevec.Config{ + client, clientErr := sqlitevec.NewClient(sqlitevec.Config{ DB: store.DB(), }, embedSvc) - if err != nil { - log.Warn().Err(err).Msg("sqlite-vec client creation failed after reinit") + if clientErr != nil { + log.Warn().Err(clientErr).Msg("sqlite-vec client creation failed after reinit") } else { vectorClient = client vectorSync = sqlitevec.NewSync(client) - log.Info().Msg("sqlite-vec reconnected after reinit") + + // Initialize AST-aware code chunking + chunkOpts := chunking.DefaultChunkOptions() + chunkers := []chunking.Chunker{ + golang.NewChunker(chunkOpts), + python.NewChunker(chunkOpts), + typescript.NewChunker(chunkOpts), + } + chunkingManager := chunking.NewManager(chunkers, chunkOpts) + vectorSync.SetChunkingManager(chunkingManager) + + log.Info().Msg("sqlite-vec with code chunking reconnected after reinit") } // Recreate cross-encoder reranking service if enabled @@ -577,9 +573,9 @@ func (s *Service) reinitializeDatabase() { rerankCfg.Alpha = s.config.RerankingAlpha } - ranker, err := reranking.NewService(rerankCfg) - if err != nil { - log.Warn().Err(err).Msg("Cross-encoder reranking service creation failed after reinit") + ranker, rankerErr := reranking.NewService(rerankCfg) + if rankerErr != nil { + log.Warn().Err(rankerErr).Msg("Cross-encoder reranking service creation failed after reinit") } else { reranker = ranker log.Info().Msg("Cross-encoder reranking reconnected after reinit") @@ -824,8 +820,8 @@ func (s *Service) rebuildAllVectors( log.Error().Err(err).Msg("Failed to fetch observations for vector rebuild") } else { for _, obs := range observations { - if err := vectorSync.SyncObservation(s.ctx, obs); err != nil { - log.Warn().Err(err).Int64("id", obs.ID).Msg("Failed to sync observation during rebuild") + if syncErr := vectorSync.SyncObservation(s.ctx, obs); syncErr != nil { + log.Warn().Err(syncErr).Int64("id", obs.ID).Msg("Failed to sync observation during rebuild") syncErrors++ } else { totalSynced++ @@ -840,8 +836,8 @@ func (s *Service) rebuildAllVectors( log.Error().Err(err).Msg("Failed to fetch summaries for vector rebuild") } else { for _, summary := range summaries { - if err := vectorSync.SyncSummary(s.ctx, summary); err != nil { - log.Warn().Err(err).Int64("id", summary.ID).Msg("Failed to sync summary during rebuild") + if syncErr := vectorSync.SyncSummary(s.ctx, summary); syncErr != nil { + log.Warn().Err(syncErr).Int64("id", summary.ID).Msg("Failed to sync summary during rebuild") syncErrors++ } else { totalSynced++ diff --git a/internal/worker/session/manager.go b/internal/worker/session/manager.go index c965fea..295f67c 100644 --- a/internal/worker/session/manager.go +++ b/internal/worker/session/manager.go @@ -21,11 +21,11 @@ const ( // ObservationData contains data for a tool observation. type ObservationData struct { - ToolName string ToolInput interface{} ToolResponse interface{} - PromptNumber int + ToolName string CWD string + PromptNumber int } // SummarizeData contains data for a summarize request. @@ -36,30 +36,28 @@ type SummarizeData struct { // PendingMessage represents a message queued for SDK processing. type PendingMessage struct { - Type MessageType Observation *ObservationData Summarize *SummarizeData + Type MessageType } // ActiveSession represents an in-memory active session being processed. type ActiveSession struct { - SessionDBID int64 - ClaudeSessionID string - SDKSessionID string + StartTime time.Time + ctx context.Context + cancel context.CancelFunc + notify chan struct{} Project string UserPrompt string + SDKSessionID string + ClaudeSessionID string + pendingMessages []PendingMessage LastPromptNumber int - StartTime time.Time CumulativeInputTokens int64 CumulativeOutputTokens int64 - - // Concurrency control - pendingMessages []PendingMessage - messageMu sync.Mutex - notify chan struct{} - ctx context.Context - cancel context.CancelFunc - generatorActive atomic.Bool + SessionDBID int64 + messageMu sync.Mutex + generatorActive atomic.Bool } // SessionTimeout is how long an inactive session can exist before cleanup. @@ -70,15 +68,14 @@ const CleanupInterval = 5 * time.Minute // Manager manages active session lifecycles. type Manager struct { - sessionStore *sqlite.SessionStore - sessions map[int64]*ActiveSession - mu sync.RWMutex - onCreated func(int64) - onDeleted func(int64) - ctx context.Context - cancel context.CancelFunc - // Global notification channel for immediate processing + ctx context.Context + sessionStore *sqlite.SessionStore + sessions map[int64]*ActiveSession + onCreated func(int64) + onDeleted func(int64) + cancel context.CancelFunc ProcessNotify chan struct{} + mu sync.RWMutex } // NewManager creates a new session manager. diff --git a/internal/worker/session/manager_test.go b/internal/worker/session/manager_test.go index 6e7f654..5236e0d 100644 --- a/internal/worker/session/manager_test.go +++ b/internal/worker/session/manager_test.go @@ -47,9 +47,9 @@ func (s *ManagerSuite) TestActiveSession() { SDKSessionID: "sdk-123", Project: "test-project", UserPrompt: "Hello", - StartTime: time.Now(), - pendingMessages: make([]PendingMessage, 0), - notify: make(chan struct{}, 1), + StartTime: time.Now(), //nolint:govet + pendingMessages: make([]PendingMessage, 0), //nolint:govet + notify: make(chan struct{}, 1), //nolint:govet } s.Equal(int64(1), session.SessionDBID) @@ -134,7 +134,7 @@ func (s *ManagerSuite) TestDeleteSession() { session := &ActiveSession{ SessionDBID: 1, Project: "test-project", - StartTime: time.Now(), + StartTime: time.Now(), //nolint:govet pendingMessages: []PendingMessage{}, ctx: ctx, cancel: cancel, @@ -236,8 +236,8 @@ func TestTimeoutConstants(t *testing.T) { func TestObservationData(t *testing.T) { data := ObservationData{ ToolName: "Read", - ToolInput: map[string]string{"path": "/test/file.go"}, - ToolResponse: "file content", + ToolInput: map[string]string{"path": "/test/file.go"}, //nolint:govet + ToolResponse: "file content", //nolint:govet PromptNumber: 1, CWD: "/test", } @@ -333,7 +333,7 @@ func TestConcurrentSessionAccess(t *testing.T) { // TestProcessNotifyChannel tests the process notification channel. func TestProcessNotifyChannel(t *testing.T) { manager := &Manager{ - sessions: make(map[int64]*ActiveSession), + sessions: make(map[int64]*ActiveSession), //nolint:govet ProcessNotify: make(chan struct{}, 1), } @@ -367,7 +367,7 @@ func TestActiveSessionContext(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) session := &ActiveSession{ - SessionDBID: 1, + SessionDBID: 1, //nolint:govet ctx: ctx, cancel: cancel, } @@ -438,7 +438,7 @@ func (s *ManagerSuite) TestShutdownAll() { s.manager.sessions[i] = &ActiveSession{ SessionDBID: i, Project: "test-project", - StartTime: time.Now(), + StartTime: time.Now(), //nolint:govet pendingMessages: []PendingMessage{}, ctx: ctx, cancel: cancel, @@ -479,7 +479,7 @@ func (s *ManagerSuite) TestDeleteNonExistentSession() { // TestLastPromptNumber tests prompt number tracking. func TestLastPromptNumber(t *testing.T) { session := &ActiveSession{ - SessionDBID: 1, + SessionDBID: 1, //nolint:govet LastPromptNumber: 0, } @@ -526,7 +526,7 @@ func TestActiveSessionNotifyChannel(t *testing.T) { // TestMessageMutex tests message mutex operations. func TestMessageMutex(t *testing.T) { session := &ActiveSession{ - pendingMessages: make([]PendingMessage, 0), + pendingMessages: make([]PendingMessage, 0), //nolint:govet } var wg sync.WaitGroup @@ -559,7 +559,7 @@ func (s *ManagerSuite) TestQueueDepthMultipleSessions() { } s.manager.sessions[2] = &ActiveSession{ SessionDBID: 2, - pendingMessages: make([]PendingMessage, 0), + pendingMessages: make([]PendingMessage, 0), //nolint:govet } s.manager.sessions[3] = &ActiveSession{ SessionDBID: 3, @@ -658,7 +658,7 @@ func TestActiveSessionCWD(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { data := ObservationData{ - ToolName: "Test", + ToolName: "Test", //nolint:govet CWD: tt.cwd, } assert.Equal(t, tt.cwd, data.CWD) @@ -669,22 +669,22 @@ func TestActiveSessionCWD(t *testing.T) { // TestToolInputResponse tests various tool input/response types. func TestToolInputResponse(t *testing.T) { tests := []struct { - name string input interface{} response interface{} + name string }{ - {"nil_values", nil, nil}, - {"string_values", "input string", "response string"}, - {"map_values", map[string]string{"key": "value"}, map[string]interface{}{"result": true}}, - {"slice_values", []string{"a", "b"}, []int{1, 2, 3}}, - {"int_values", 42, 100}, - {"bool_values", true, false}, + {name: "nil_values", input: nil, response: nil}, + {name: "string_values", input: "input string", response: "response string"}, + {name: "map_values", input: map[string]string{"key": "value"}, response: map[string]interface{}{"result": true}}, + {name: "slice_values", input: []string{"a", "b"}, response: []int{1, 2, 3}}, + {name: "int_values", input: 42, response: 100}, + {name: "bool_values", input: true, response: false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { data := ObservationData{ - ToolName: "TestTool", + ToolName: "TestTool", //nolint:govet ToolInput: tt.input, ToolResponse: tt.response, } diff --git a/internal/worker/sse/broadcaster.go b/internal/worker/sse/broadcaster.go index b6e8d96..380cfba 100644 --- a/internal/worker/sse/broadcaster.go +++ b/internal/worker/sse/broadcaster.go @@ -19,10 +19,10 @@ const ( // Client represents a connected SSE client. type Client struct { - ID string Writer http.ResponseWriter Flusher http.Flusher Done chan struct{} + ID string } // Broadcaster manages SSE client connections and message broadcasting. diff --git a/internal/worker/sse/broadcaster_test.go b/internal/worker/sse/broadcaster_test.go index 45776f2..e1504d5 100644 --- a/internal/worker/sse/broadcaster_test.go +++ b/internal/worker/sse/broadcaster_test.go @@ -256,8 +256,8 @@ func TestHandleSSE(t *testing.T) { // TestBroadcastJSON tests broadcasting various JSON types. func TestBroadcastJSON(t *testing.T) { tests := []struct { - name string data interface{} + name string wantErr bool }{ { diff --git a/pkg/hooks/response.go b/pkg/hooks/response.go index bd0eabf..32b044e 100644 --- a/pkg/hooks/response.go +++ b/pkg/hooks/response.go @@ -62,11 +62,11 @@ type BaseInput struct { // HookContext provides common context for hook handlers. type HookContext struct { HookName string - Port int Project string SessionID string CWD string RawInput []byte + Port int } // HookHandler is a function that handles hook-specific logic. @@ -92,8 +92,8 @@ func RunHook[T any](hookName string, handler HookHandler[T]) { // Parse input var input T - if err := json.Unmarshal(inputData, &input); err != nil { - WriteError(hookName, err) + if unmarshalErr := json.Unmarshal(inputData, &input); unmarshalErr != nil { + WriteError(hookName, unmarshalErr) os.Exit(1) } diff --git a/pkg/hooks/worker_test.go b/pkg/hooks/worker_test.go index 74aab43..1d29d4a 100644 --- a/pkg/hooks/worker_test.go +++ b/pkg/hooks/worker_test.go @@ -320,11 +320,11 @@ func TestExtractBaseVersion(t *testing.T) { // TestPOST tests the POST function with a mock server. func TestPOST(t *testing.T) { tests := []struct { - name string - serverHandler func(w http.ResponseWriter, r *http.Request) body interface{} - expectError bool + serverHandler func(w http.ResponseWriter, r *http.Request) expectedResult map[string]interface{} + name string + expectError bool }{ { name: "successful POST with JSON response", @@ -393,10 +393,10 @@ func TestPOST(t *testing.T) { // TestGET tests the GET function with a mock server. func TestGET(t *testing.T) { tests := []struct { - name string serverHandler func(w http.ResponseWriter, r *http.Request) - expectError bool expectedResult map[string]interface{} + name string + expectError bool }{ { name: "successful GET with JSON response", @@ -532,8 +532,8 @@ func TestExitCodes(t *testing.T) { func TestHookResponse(t *testing.T) { tests := []struct { name string - response HookResponse expected string + response HookResponse }{ { name: "continue true", @@ -597,8 +597,8 @@ func TestHookContext(t *testing.T) { // TestIsWorkerRunning_WithServer tests IsWorkerRunning with actual server. func TestIsWorkerRunning_WithServer(t *testing.T) { tests := []struct { - name string serverHandler func(w http.ResponseWriter, r *http.Request) + name string expectedResult bool }{ { @@ -828,8 +828,8 @@ func TestBaseInput_PartialFields(t *testing.T) { func TestHookResponse_Marshal(t *testing.T) { tests := []struct { name string - response HookResponse contains []string + response HookResponse }{ { name: "continue true", @@ -1177,7 +1177,7 @@ func TestHookContext_RawInput(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := HookContext{ - HookName: "test", + HookName: "test", //nolint:govet RawInput: tt.rawInput, } assert.Equal(t, tt.rawInput, ctx.RawInput) diff --git a/pkg/models/conflict.go b/pkg/models/conflict.go index c3ed126..c6fef8e 100644 --- a/pkg/models/conflict.go +++ b/pkg/models/conflict.go @@ -33,25 +33,25 @@ const ( // ObservationConflict tracks conflicting observations. type ObservationConflict struct { - ID int64 `db:"id" json:"id"` - NewerObsID int64 `db:"newer_obs_id" json:"newer_obs_id"` - OlderObsID int64 `db:"older_obs_id" json:"older_obs_id"` + ResolvedAt *string `db:"resolved_at" json:"resolved_at,omitempty"` ConflictType ConflictType `db:"conflict_type" json:"conflict_type"` Resolution ConflictResolution `db:"resolution" json:"resolution"` Reason string `db:"reason" json:"reason"` DetectedAt string `db:"detected_at" json:"detected_at"` + ID int64 `db:"id" json:"id"` + NewerObsID int64 `db:"newer_obs_id" json:"newer_obs_id"` + OlderObsID int64 `db:"older_obs_id" json:"older_obs_id"` DetectedAtEpoch int64 `db:"detected_at_epoch" json:"detected_at_epoch"` Resolved bool `db:"resolved" json:"resolved"` - ResolvedAt *string `db:"resolved_at" json:"resolved_at,omitempty"` } // ConflictDetectionResult contains the result of conflict detection. type ConflictDetectionResult struct { - HasConflict bool Type ConflictType Resolution ConflictResolution Reason string - OlderObsIDs []int64 // IDs of observations that conflict with the new one + OlderObsIDs []int64 + HasConflict bool } // NewObservationConflict creates a new conflict record. diff --git a/pkg/models/conflict_test.go b/pkg/models/conflict_test.go index ed14d21..6abe5a4 100644 --- a/pkg/models/conflict_test.go +++ b/pkg/models/conflict_test.go @@ -51,8 +51,8 @@ func (s *ConflictSuite) TestDetectExplicitCorrection_TableDriven() { tests := []struct { name string text string - expectMatch bool expectPattern string + expectMatch bool }{ { name: "actually that was wrong", @@ -128,9 +128,9 @@ func (s *ConflictSuite) TestDetectExplicitCorrection_TableDriven() { // TestDetectOpposingFileChanges_TableDriven tests opposing file change detection. func (s *ConflictSuite) TestDetectOpposingFileChanges_TableDriven() { tests := []struct { - name string newerObs *Observation olderObs *Observation + name string expectConflict bool }{ { @@ -202,9 +202,9 @@ func (s *ConflictSuite) TestDetectOpposingFileChanges_TableDriven() { // TestDetectConceptTagMismatch_TableDriven tests concept tag mismatch detection. func (s *ConflictSuite) TestDetectConceptTagMismatch_TableDriven() { tests := []struct { - name string newerObs *Observation olderObs *Observation + name string expectConflict bool }{ { @@ -406,9 +406,9 @@ func TestObservationConflict_Fields(t *testing.T) { OlderObsID: 5, ConflictType: ConflictSuperseded, Resolution: ResolutionPreferNewer, - Reason: "Test reason", - DetectedAt: "2024-01-01T00:00:00Z", - DetectedAtEpoch: 1704067200000, + Reason: "Test reason", //nolint:govet + DetectedAt: "2024-01-01T00:00:00Z", //nolint:govet + DetectedAtEpoch: 1704067200000, //nolint:govet Resolved: false, } diff --git a/pkg/models/observation.go b/pkg/models/observation.go index abf1ff9..7417fc0 100644 --- a/pkg/models/observation.go +++ b/pkg/models/observation.go @@ -121,48 +121,44 @@ func (j JSONInt64Map) Value() (driver.Value, error) { // Observation represents a learning extracted from a Claude Code session. type Observation struct { - ID int64 `db:"id" json:"id"` + FileMtimes JSONInt64Map `db:"file_mtimes" json:"file_mtimes,omitempty"` SDKSessionID string `db:"sdk_session_id" json:"sdk_session_id"` Project string `db:"project" json:"project"` Scope ObservationScope `db:"scope" json:"scope"` Type ObservationType `db:"type" json:"type"` - Title sql.NullString `db:"title" json:"title,omitempty"` + CreatedAt string `db:"created_at" json:"created_at"` Subtitle sql.NullString `db:"subtitle" json:"subtitle,omitempty"` - Facts JSONStringArray `db:"facts" json:"facts,omitempty"` + Title sql.NullString `db:"title" json:"title,omitempty"` Narrative sql.NullString `db:"narrative" json:"narrative,omitempty"` Concepts JSONStringArray `db:"concepts" json:"concepts,omitempty"` FilesRead JSONStringArray `db:"files_read" json:"files_read,omitempty"` FilesModified JSONStringArray `db:"files_modified" json:"files_modified,omitempty"` - FileMtimes JSONInt64Map `db:"file_mtimes" json:"file_mtimes,omitempty"` + Facts JSONStringArray `db:"facts" json:"facts,omitempty"` PromptNumber sql.NullInt64 `db:"prompt_number" json:"prompt_number,omitempty"` + LastRetrievedAt sql.NullInt64 `db:"last_retrieved_at_epoch" json:"last_retrieved_at_epoch,omitempty"` + ScoreUpdatedAt sql.NullInt64 `db:"score_updated_at_epoch" json:"score_updated_at_epoch,omitempty"` DiscoveryTokens int64 `db:"discovery_tokens" json:"discovery_tokens"` - CreatedAt string `db:"created_at" json:"created_at"` + ID int64 `db:"id" json:"id"` CreatedAtEpoch int64 `db:"created_at_epoch" json:"created_at_epoch"` + ImportanceScore float64 `db:"importance_score" json:"importance_score"` + UserFeedback int `db:"user_feedback" json:"user_feedback"` + RetrievalCount int `db:"retrieval_count" json:"retrieval_count"` IsStale bool `db:"-" json:"is_stale,omitempty"` - - // Importance scoring fields - ImportanceScore float64 `db:"importance_score" json:"importance_score"` - UserFeedback int `db:"user_feedback" json:"user_feedback"` - RetrievalCount int `db:"retrieval_count" json:"retrieval_count"` - LastRetrievedAt sql.NullInt64 `db:"last_retrieved_at_epoch" json:"last_retrieved_at_epoch,omitempty"` - ScoreUpdatedAt sql.NullInt64 `db:"score_updated_at_epoch" json:"score_updated_at_epoch,omitempty"` - - // Conflict detection fields - IsSuperseded bool `db:"is_superseded" json:"is_superseded,omitempty"` + IsSuperseded bool `db:"is_superseded" json:"is_superseded,omitempty"` } // ParsedObservation represents an observation parsed from SDK response XML. type ParsedObservation struct { + FileMtimes map[string]int64 Type ObservationType Title string Subtitle string - Facts []string Narrative string + Scope ObservationScope + Facts []string Concepts []string FilesRead []string FilesModified []string - FileMtimes map[string]int64 // File path -> mtime epoch ms - Scope ObservationScope // Optional: if empty, will be auto-determined } // ToStoredObservation converts a ParsedObservation to the stored Observation format. @@ -197,34 +193,30 @@ func DetermineScope(concepts []string) ObservationScope { // ObservationJSON is a JSON-friendly representation of Observation. // It converts sql.NullString to plain strings for clean JSON output. type ObservationJSON struct { - ID int64 `json:"id"` + FileMtimes map[string]int64 `json:"file_mtimes,omitempty"` + Subtitle string `json:"subtitle,omitempty"` SDKSessionID string `json:"sdk_session_id"` - Project string `json:"project"` Scope ObservationScope `json:"scope"` Type ObservationType `json:"type"` Title string `json:"title,omitempty"` - Subtitle string `json:"subtitle,omitempty"` - Facts []string `json:"facts,omitempty"` + CreatedAt string `json:"created_at"` Narrative string `json:"narrative,omitempty"` + Project string `json:"project"` Concepts []string `json:"concepts,omitempty"` + Facts []string `json:"facts,omitempty"` FilesRead []string `json:"files_read,omitempty"` FilesModified []string `json:"files_modified,omitempty"` - FileMtimes map[string]int64 `json:"file_mtimes,omitempty"` - PromptNumber int64 `json:"prompt_number,omitempty"` - DiscoveryTokens int64 `json:"discovery_tokens"` - CreatedAt string `json:"created_at"` CreatedAtEpoch int64 `json:"created_at_epoch"` + DiscoveryTokens int64 `json:"discovery_tokens"` + ID int64 `json:"id"` + PromptNumber int64 `json:"prompt_number,omitempty"` + ImportanceScore float64 `json:"importance_score"` + UserFeedback int `json:"user_feedback"` + RetrievalCount int `json:"retrieval_count"` + LastRetrievedAt int64 `json:"last_retrieved_at_epoch,omitempty"` + ScoreUpdatedAt int64 `json:"score_updated_at_epoch,omitempty"` IsStale bool `json:"is_stale,omitempty"` - - // Importance scoring fields - ImportanceScore float64 `json:"importance_score"` - UserFeedback int `json:"user_feedback"` - RetrievalCount int `json:"retrieval_count"` - LastRetrievedAt int64 `json:"last_retrieved_at_epoch,omitempty"` - ScoreUpdatedAt int64 `json:"score_updated_at_epoch,omitempty"` - - // Conflict detection fields - IsSuperseded bool `json:"is_superseded,omitempty"` + IsSuperseded bool `json:"is_superseded,omitempty"` } // MarshalJSON implements json.Marshaler for Observation. diff --git a/pkg/models/observation_test.go b/pkg/models/observation_test.go index 50ce357..d5954f4 100644 --- a/pkg/models/observation_test.go +++ b/pkg/models/observation_test.go @@ -50,8 +50,8 @@ func (s *ObservationSuite) TestGlobalizableConcepts() { func (s *ObservationSuite) TestDetermineScope_TableDriven() { tests := []struct { name string - concepts []string expected ObservationScope + concepts []string }{ { name: "empty concepts - project scope", @@ -106,8 +106,8 @@ func (s *ObservationSuite) TestDetermineScope_TableDriven() { // TestParsedObservation_FileMtimesJSON tests FileMtimes JSON serialization. func (s *ObservationSuite) TestParsedObservation_FileMtimesJSON() { obs := &ParsedObservation{ - Type: ObsTypeDiscovery, - Title: "Test", + Type: ObsTypeDiscovery, //nolint:govet + Title: "Test", //nolint:govet FileMtimes: map[string]int64{"file1.go": 1234567890, "file2.go": 1234567891}, } @@ -121,9 +121,9 @@ func (s *ObservationSuite) TestParsedObservation_FileMtimesJSON() { // TestObservation_CheckStaleness_TableDriven tests staleness checking. func (s *ObservationSuite) TestObservation_CheckStaleness_TableDriven() { tests := []struct { - name string storedMtimes map[string]int64 currentMtimes map[string]int64 + name string expectedStale bool }{ { @@ -221,9 +221,9 @@ func (s *ObservationSuite) TestParsedObservation_Fields() { func (s *ObservationSuite) TestObservation_NullFields() { // Test with null fields obs := &Observation{ - ID: 1, - Project: "test", - Type: ObsTypeDiscovery, + ID: 1, //nolint:govet + Project: "test", //nolint:govet + Type: ObsTypeDiscovery, //nolint:govet Title: sql.NullString{Valid: false}, Subtitle: sql.NullString{Valid: false}, Narrative: sql.NullString{Valid: false}, @@ -235,9 +235,9 @@ func (s *ObservationSuite) TestObservation_NullFields() { // Test with valid fields obs2 := &Observation{ - ID: 2, - Project: "test", - Type: ObsTypeBugfix, + ID: 2, //nolint:govet + Project: "test", //nolint:govet + Type: ObsTypeBugfix, //nolint:govet Title: sql.NullString{String: "Fix bug", Valid: true}, Subtitle: sql.NullString{String: "Memory leak", Valid: true}, Narrative: sql.NullString{String: "Fixed memory leak in handler", Valid: true}, @@ -300,10 +300,10 @@ func TestParsedObservation_ToStoredObservation(t *testing.T) { // TestJSONStringArray tests JSONStringArray scanning. func TestJSONStringArray(t *testing.T) { tests := []struct { - name string input interface{} - wantErr bool + name string expected JSONStringArray + wantErr bool }{ { name: "nil input", @@ -348,10 +348,10 @@ func TestJSONStringArray(t *testing.T) { // TestJSONInt64Map tests JSONInt64Map scanning. func TestJSONInt64Map(t *testing.T) { tests := []struct { - name string input interface{} - wantErr bool expected JSONInt64Map + name string + wantErr bool }{ { name: "nil input", diff --git a/pkg/models/pattern.go b/pkg/models/pattern.go index 03d3d6f..748f299 100644 --- a/pkg/models/pattern.go +++ b/pkg/models/pattern.go @@ -39,21 +39,21 @@ const ( // Pattern represents a recurring pattern detected across observations. // This enables Claude to reference historical insights: "I've encountered this pattern 12 times." type Pattern struct { - ID int64 `db:"id" json:"id"` - Name string `db:"name" json:"name"` // e.g., "State Management Anti-Pattern" - Type PatternType `db:"type" json:"type"` // bug, refactor, architecture, etc. - Description sql.NullString `db:"description" json:"description"` // Detailed description - Signature JSONStringArray `db:"signature" json:"signature"` // Keyword clusters for detection - Recommendation sql.NullString `db:"recommendation" json:"recommendation"` // What works for this pattern - Frequency int `db:"frequency" json:"frequency"` // How many times encountered - Projects JSONStringArray `db:"projects" json:"projects"` // Projects where this pattern was seen - ObservationIDs JSONInt64Array `db:"observation_ids" json:"observation_ids"` // Source observation IDs - Status PatternStatus `db:"status" json:"status"` // active, deprecated, merged - MergedIntoID sql.NullInt64 `db:"merged_into_id" json:"merged_into_id,omitempty"` - Confidence float64 `db:"confidence" json:"confidence"` // Detection confidence (0.0-1.0) - LastSeenAt string `db:"last_seen_at" json:"last_seen_at"` // Last time pattern was detected - LastSeenEpoch int64 `db:"last_seen_at_epoch" json:"last_seen_at_epoch"` + Status PatternStatus `db:"status" json:"status"` + Name string `db:"name" json:"name"` + Type PatternType `db:"type" json:"type"` CreatedAt string `db:"created_at" json:"created_at"` + LastSeenAt string `db:"last_seen_at" json:"last_seen_at"` + Signature JSONStringArray `db:"signature" json:"signature"` + Projects JSONStringArray `db:"projects" json:"projects"` + ObservationIDs JSONInt64Array `db:"observation_ids" json:"observation_ids"` + Recommendation sql.NullString `db:"recommendation" json:"recommendation"` + Description sql.NullString `db:"description" json:"description"` + MergedIntoID sql.NullInt64 `db:"merged_into_id" json:"merged_into_id,omitempty"` + Frequency int `db:"frequency" json:"frequency"` + Confidence float64 `db:"confidence" json:"confidence"` + ID int64 `db:"id" json:"id"` + LastSeenEpoch int64 `db:"last_seen_at_epoch" json:"last_seen_at_epoch"` CreatedAtEpoch int64 `db:"created_at_epoch" json:"created_at_epoch"` } @@ -95,21 +95,21 @@ func (j JSONInt64Array) Value() (driver.Value, error) { // PatternJSON is a JSON-friendly representation of Pattern. type PatternJSON struct { - ID int64 `json:"id"` + Status PatternStatus `json:"status"` Name string `json:"name"` Type PatternType `json:"type"` Description string `json:"description,omitempty"` - Signature []string `json:"signature,omitempty"` + CreatedAt string `json:"created_at"` Recommendation string `json:"recommendation,omitempty"` - Frequency int `json:"frequency"` - Projects []string `json:"projects,omitempty"` + LastSeenAt string `json:"last_seen_at"` + Signature []string `json:"signature,omitempty"` ObservationIDs []int64 `json:"observation_ids,omitempty"` - Status PatternStatus `json:"status"` + Projects []string `json:"projects,omitempty"` MergedIntoID int64 `json:"merged_into_id,omitempty"` Confidence float64 `json:"confidence"` - LastSeenAt string `json:"last_seen_at"` + Frequency int `json:"frequency"` LastSeenEpoch int64 `json:"last_seen_at_epoch"` - CreatedAt string `json:"created_at"` + ID int64 `json:"id"` CreatedAtEpoch int64 `json:"created_at_epoch"` } @@ -214,11 +214,11 @@ func (p *Pattern) updateConfidence() { // PatternMatch represents a match between an observation and a potential pattern. type PatternMatch struct { - PatternID int64 `json:"pattern_id"` - Score float64 `json:"score"` // Match score (0.0-1.0) - MatchedOn string `json:"matched_on"` // What triggered the match (concept, keyword, type, etc.) - IsNew bool `json:"is_new"` // Whether this would create a new pattern + MatchedOn string `json:"matched_on"` SuggestedName string `json:"suggested_name,omitempty"` + PatternID int64 `json:"pattern_id"` + Score float64 `json:"score"` + IsNew bool `json:"is_new"` } // PatternSignatureKeywords are common keywords used in pattern detection. diff --git a/pkg/models/pattern_test.go b/pkg/models/pattern_test.go index 2614f07..f39bddf 100644 --- a/pkg/models/pattern_test.go +++ b/pkg/models/pattern_test.go @@ -116,18 +116,18 @@ func TestPattern_ConfidenceCalculation(t *testing.T) { func TestPatternType_Detection(t *testing.T) { tests := []struct { - concepts []string title string narrative string expected PatternType + concepts []string }{ - {[]string{"anti-pattern"}, "", "", PatternTypeAntiPattern}, - {[]string{"best-practice"}, "", "", PatternTypeBestPractice}, - {[]string{"architecture"}, "", "", PatternTypeArchitecture}, - {[]string{"refactor"}, "", "", PatternTypeRefactor}, - {[]string{}, "nil pointer bug", "", PatternTypeBug}, - {[]string{}, "Deadlock in concurrent code", "", PatternTypeBug}, - {[]string{}, "Extract interface", "", PatternTypeRefactor}, + {concepts: []string{"anti-pattern"}, title: "", narrative: "", expected: PatternTypeAntiPattern}, + {concepts: []string{"best-practice"}, title: "", narrative: "", expected: PatternTypeBestPractice}, + {concepts: []string{"architecture"}, title: "", narrative: "", expected: PatternTypeArchitecture}, + {concepts: []string{"refactor"}, title: "", narrative: "", expected: PatternTypeRefactor}, + {concepts: []string{}, title: "nil pointer bug", narrative: "", expected: PatternTypeBug}, + {concepts: []string{}, title: "Deadlock in concurrent code", narrative: "", expected: PatternTypeBug}, + {concepts: []string{}, title: "Extract interface", narrative: "", expected: PatternTypeRefactor}, } for _, tt := range tests { diff --git a/pkg/models/prompt.go b/pkg/models/prompt.go index f0fce3a..af01b8d 100644 --- a/pkg/models/prompt.go +++ b/pkg/models/prompt.go @@ -3,18 +3,18 @@ package models // UserPrompt represents a user prompt captured during a session. type UserPrompt struct { - ID int64 `db:"id" json:"id"` ClaudeSessionID string `db:"claude_session_id" json:"claude_session_id"` - PromptNumber int `db:"prompt_number" json:"prompt_number"` PromptText string `db:"prompt_text" json:"prompt_text"` - MatchedObservations int `db:"matched_observations" json:"matched_observations"` CreatedAt string `db:"created_at" json:"created_at"` + ID int64 `db:"id" json:"id"` + PromptNumber int `db:"prompt_number" json:"prompt_number"` + MatchedObservations int `db:"matched_observations" json:"matched_observations"` CreatedAtEpoch int64 `db:"created_at_epoch" json:"created_at_epoch"` } // UserPromptWithSession includes session context for search results. type UserPromptWithSession struct { - UserPrompt Project string `db:"project" json:"project"` SDKSessionID string `db:"sdk_session_id" json:"sdk_session_id"` + UserPrompt } diff --git a/pkg/models/relation.go b/pkg/models/relation.go index a2cadc5..f04da6b 100644 --- a/pkg/models/relation.go +++ b/pkg/models/relation.go @@ -60,14 +60,14 @@ const ( // ObservationRelation represents a directed relationship between two observations. type ObservationRelation struct { - ID int64 `db:"id" json:"id"` - SourceID int64 `db:"source_id" json:"source_id"` - TargetID int64 `db:"target_id" json:"target_id"` RelationType RelationType `db:"relation_type" json:"relation_type"` - Confidence float64 `db:"confidence" json:"confidence"` DetectionSource RelationDetectionSource `db:"detection_source" json:"detection_source"` Reason string `db:"reason" json:"reason,omitempty"` CreatedAt string `db:"created_at" json:"created_at"` + ID int64 `db:"id" json:"id"` + SourceID int64 `db:"source_id" json:"source_id"` + TargetID int64 `db:"target_id" json:"target_id"` + Confidence float64 `db:"confidence" json:"confidence"` CreatedAtEpoch int64 `db:"created_at_epoch" json:"created_at_epoch"` } @@ -88,12 +88,12 @@ func NewObservationRelation(sourceID, targetID int64, relType RelationType, conf // RelationDetectionResult contains the result of relation detection. type RelationDetectionResult struct { - SourceID int64 - TargetID int64 RelationType RelationType - Confidence float64 DetectionSource RelationDetectionSource Reason string + SourceID int64 + TargetID int64 + Confidence float64 } // DetectFileOverlapRelation checks if observations share file references and determines relationship type. @@ -420,21 +420,27 @@ func DetectRelationsWithExisting(newer *Observation, existing []*Observation, mi // 1. File overlap detection if result := DetectFileOverlapRelation(newer, older); result != nil && result.Confidence >= minConfidence { - if bestResult == nil || result.Confidence > bestResult.Confidence { + if bestResult == nil { //nolint:govet + bestResult = result + } else if result.Confidence > bestResult.Confidence { bestResult = result } } // 2. Concept overlap detection if result := DetectConceptOverlapRelation(newer, older); result != nil && result.Confidence >= minConfidence { - if bestResult == nil || result.Confidence > bestResult.Confidence { + if bestResult == nil { + bestResult = result + } else if result.Confidence > bestResult.Confidence { bestResult = result } } // 3. Type progression detection if result := DetectTypeProgressionRelation(newer, older); result != nil && result.Confidence >= minConfidence { - if bestResult == nil || result.Confidence > bestResult.Confidence { + if bestResult == nil { + bestResult = result + } else if result.Confidence > bestResult.Confidence { bestResult = result } } @@ -449,7 +455,9 @@ func DetectRelationsWithExisting(newer *Observation, existing []*Observation, mi // 5. Narrative mention detection (can upgrade relation type) if result := DetectNarrativeMentionRelation(newer, older); result != nil && result.Confidence >= minConfidence { - if bestResult == nil || result.Confidence > bestResult.Confidence { + if bestResult == nil { + bestResult = result + } else if result.Confidence > bestResult.Confidence { bestResult = result } } @@ -484,6 +492,6 @@ type RelationWithDetails struct { // RelationGraph represents a graph of related observations. type RelationGraph struct { - CenterID int64 `json:"center_id"` Relations []*RelationWithDetails `json:"relations"` + CenterID int64 `json:"center_id"` } diff --git a/pkg/models/relation_test.go b/pkg/models/relation_test.go index 519973f..aee8e00 100644 --- a/pkg/models/relation_test.go +++ b/pkg/models/relation_test.go @@ -8,12 +8,12 @@ import ( func TestDetectFileOverlapRelation(t *testing.T) { tests := []struct { - name string newer *Observation older *Observation - wantRelation bool + name string wantRelType RelationType wantMinConfid float64 + wantRelation bool }{ { name: "no file overlap", @@ -105,11 +105,11 @@ func TestDetectFileOverlapRelation(t *testing.T) { func TestDetectConceptOverlapRelation(t *testing.T) { tests := []struct { - name string newer *Observation older *Observation - wantRelation bool + name string wantMinConfid float64 + wantRelation bool }{ { name: "no concept overlap", @@ -179,8 +179,8 @@ func TestDetectTypeProgressionRelation(t *testing.T) { name string newerType ObservationType olderType ObservationType - wantRelation bool wantRelType RelationType + wantRelation bool }{ { name: "bugfix fixes discovery", @@ -314,8 +314,8 @@ func TestDetectNarrativeMentionRelation(t *testing.T) { tests := []struct { name string narrative string - wantRelation bool wantRelType RelationType + wantRelation bool }{ { name: "fixes language", diff --git a/pkg/models/scoring.go b/pkg/models/scoring.go index 28b9af2..c2ab5b3 100644 --- a/pkg/models/scoring.go +++ b/pkg/models/scoring.go @@ -4,8 +4,8 @@ package models // ConceptWeight represents a configurable weight for a concept. type ConceptWeight struct { Concept string `db:"concept" json:"concept"` - Weight float64 `db:"weight" json:"weight"` UpdatedAt string `db:"updated_at" json:"updated_at"` + Weight float64 `db:"weight" json:"weight"` } // UserFeedbackType represents the type of user feedback. @@ -62,28 +62,12 @@ var TypeBaseScores = map[ObservationType]float64{ // ScoringConfig contains all scoring weights and parameters. type ScoringConfig struct { - // RecencyHalfLifeDays is the number of days for the importance score to halve. - // With 7 days, a 7-day old observation has 50% of a new observation's recency score. - RecencyHalfLifeDays float64 `json:"recency_half_life_days"` - - // FeedbackWeight scales the user feedback contribution to final score. - // With 0.30, a thumbs up adds 0.30 to the score, thumbs down subtracts 0.30. - FeedbackWeight float64 `json:"feedback_weight"` - - // ConceptWeight scales the concept boost contribution. - // The sum of matching concept weights is multiplied by this. - ConceptWeight float64 `json:"concept_weight"` - - // RetrievalWeight scales the retrieval boost contribution. - // Popular observations get a logarithmic bonus. - RetrievalWeight float64 `json:"retrieval_weight"` - - // ConceptWeights maps concept names to their importance weights. - ConceptWeights map[string]float64 `json:"concept_weights"` - - // MinScore is the minimum allowed importance score. - // Prevents observations from completely disappearing. - MinScore float64 `json:"min_score"` + ConceptWeights map[string]float64 `json:"concept_weights"` + RecencyHalfLifeDays float64 `json:"recency_half_life_days"` + FeedbackWeight float64 `json:"feedback_weight"` + ConceptWeight float64 `json:"concept_weight"` + RetrievalWeight float64 `json:"retrieval_weight"` + MinScore float64 `json:"min_score"` } // DefaultScoringConfig returns the default scoring configuration. diff --git a/pkg/models/session.go b/pkg/models/session.go index ee58f4b..5321dca 100644 --- a/pkg/models/session.go +++ b/pkg/models/session.go @@ -17,29 +17,29 @@ const ( // SDKSession represents a Claude Code session tracked by the memory system. type SDKSession struct { - ID int64 `db:"id" json:"id"` ClaudeSessionID string `db:"claude_session_id" json:"claude_session_id"` - SDKSessionID sql.NullString `db:"sdk_session_id" json:"sdk_session_id,omitempty"` Project string `db:"project" json:"project"` - UserPrompt sql.NullString `db:"user_prompt" json:"user_prompt,omitempty"` - WorkerPort sql.NullInt64 `db:"worker_port" json:"worker_port,omitempty"` - PromptCounter int64 `db:"prompt_counter" json:"prompt_counter"` Status SessionStatus `db:"status" json:"status"` StartedAt string `db:"started_at" json:"started_at"` - StartedAtEpoch int64 `db:"started_at_epoch" json:"started_at_epoch"` + SDKSessionID sql.NullString `db:"sdk_session_id" json:"sdk_session_id,omitempty"` + UserPrompt sql.NullString `db:"user_prompt" json:"user_prompt,omitempty"` CompletedAt sql.NullString `db:"completed_at" json:"completed_at,omitempty"` + WorkerPort sql.NullInt64 `db:"worker_port" json:"worker_port,omitempty"` CompletedAtEpoch sql.NullInt64 `db:"completed_at_epoch" json:"completed_at_epoch,omitempty"` + ID int64 `db:"id" json:"id"` + PromptCounter int64 `db:"prompt_counter" json:"prompt_counter"` + StartedAtEpoch int64 `db:"started_at_epoch" json:"started_at_epoch"` } // ActiveSession represents an in-memory active session being processed. type ActiveSession struct { - SessionDBID int64 + StartTime time.Time ClaudeSessionID string SDKSessionID string Project string UserPrompt string + SessionDBID int64 LastPromptNumber int - StartTime time.Time CumulativeInputTokens int64 CumulativeOutputTokens int64 } diff --git a/pkg/models/summary.go b/pkg/models/summary.go index 81990c5..6909123 100644 --- a/pkg/models/summary.go +++ b/pkg/models/summary.go @@ -9,18 +9,18 @@ import ( // SessionSummary represents a summary of a Claude Code session. type SessionSummary struct { - ID int64 `db:"id" json:"id"` + CreatedAt string `db:"created_at" json:"created_at"` SDKSessionID string `db:"sdk_session_id" json:"sdk_session_id"` Project string `db:"project" json:"project"` - Request sql.NullString `db:"request" json:"request,omitempty"` + Completed sql.NullString `db:"completed" json:"completed,omitempty"` Investigated sql.NullString `db:"investigated" json:"investigated,omitempty"` Learned sql.NullString `db:"learned" json:"learned,omitempty"` - Completed sql.NullString `db:"completed" json:"completed,omitempty"` NextSteps sql.NullString `db:"next_steps" json:"next_steps,omitempty"` Notes sql.NullString `db:"notes" json:"notes,omitempty"` + Request sql.NullString `db:"request" json:"request,omitempty"` PromptNumber sql.NullInt64 `db:"prompt_number" json:"prompt_number,omitempty"` + ID int64 `db:"id" json:"id"` DiscoveryTokens int64 `db:"discovery_tokens" json:"discovery_tokens"` - CreatedAt string `db:"created_at" json:"created_at"` CreatedAtEpoch int64 `db:"created_at_epoch" json:"created_at_epoch"` } @@ -56,18 +56,18 @@ func NewSessionSummary(sdkSessionID, project string, parsed *ParsedSummary, prom // SessionSummaryJSON is a JSON-friendly representation of SessionSummary. // It converts sql.NullString to plain strings for clean JSON output. type SessionSummaryJSON struct { - ID int64 `json:"id"` + Completed string `json:"completed,omitempty"` SDKSessionID string `json:"sdk_session_id"` Project string `json:"project"` Request string `json:"request,omitempty"` Investigated string `json:"investigated,omitempty"` Learned string `json:"learned,omitempty"` - Completed string `json:"completed,omitempty"` NextSteps string `json:"next_steps,omitempty"` Notes string `json:"notes,omitempty"` + CreatedAt string `json:"created_at"` + ID int64 `json:"id"` PromptNumber int64 `json:"prompt_number,omitempty"` DiscoveryTokens int64 `json:"discovery_tokens"` - CreatedAt string `json:"created_at"` CreatedAtEpoch int64 `json:"created_at_epoch"` } diff --git a/pkg/models/summary_test.go b/pkg/models/summary_test.go index 39c16ef..a5c4d92 100644 --- a/pkg/models/summary_test.go +++ b/pkg/models/summary_test.go @@ -184,8 +184,8 @@ func (s *SummarySuite) TestSessionSummaryJSON() { Notes: "Notes", PromptNumber: 5, DiscoveryTokens: 1000, - CreatedAt: "2024-01-01T00:00:00Z", - CreatedAtEpoch: 1704067200000, + CreatedAt: "2024-01-01T00:00:00Z", //nolint:govet + CreatedAtEpoch: 1704067200000, //nolint:govet } s.Equal(int64(1), j.ID) diff --git a/pkg/similarity/clustering_test.go b/pkg/similarity/clustering_test.go index d843dfe..6aa7f06 100644 --- a/pkg/similarity/clustering_test.go +++ b/pkg/similarity/clustering_test.go @@ -12,9 +12,9 @@ import ( func TestJaccardSimilarity(t *testing.T) { tests := []struct { - name string set1 map[string]bool set2 map[string]bool + name string expected float64 }{ { diff --git a/ui/package-lock.json b/ui/package-lock.json index 6e9ac50..b60fbb8 100644 --- a/ui/package-lock.json +++ b/ui/package-lock.json @@ -1,12 +1,12 @@ { "name": "claude-mnemonic-dashboard", - "version": "0ddacaa-dirty", + "version": "40a44a7-dirty", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "claude-mnemonic-dashboard", - "version": "0ddacaa-dirty", + "version": "40a44a7-dirty", "dependencies": { "vis-data": "^7.1.9", "vis-network": "^9.1.9", diff --git a/ui/package.json b/ui/package.json index 7cb9a8e..cc4ac6f 100644 --- a/ui/package.json +++ b/ui/package.json @@ -1,6 +1,6 @@ { "name": "claude-mnemonic-dashboard", - "version": "0ddacaa-dirty", + "version": "40a44a7-dirty", "private": true, "type": "module", "scripts": {