From e07d4174de518b7d41103b32fa92d282c869b54c Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Fri, 9 Jan 2026 22:17:05 +0000 Subject: [PATCH] fix(hooks,db,mcp,worker): add type safety and error handling (#21) - [x] Add type checking and error handling for JSON type assertions in user-prompt hook - [x] Add error handling for session update query in CreateSDKSession - [x] Update MCP tool description to reference sqlite-vec instead of ChromaDB - [x] Fix MinConfidence sentinel value check from 0 to -1 - [x] Pass project parameter to vector search filter in handleSearchByPrompt - [x] Return empty map instead of nil for successful responses without JSON body --- cmd/hooks/user-prompt/main.go | 14 ++++++++++++-- internal/db/gorm/session_store.go | 7 +++++-- internal/mcp/server.go | 5 +++-- internal/worker/handlers.go | 2 +- pkg/hooks/worker.go | 4 ++-- 5 files changed, 23 insertions(+), 9 deletions(-) diff --git a/cmd/hooks/user-prompt/main.go b/cmd/hooks/user-prompt/main.go index 8edbe0f..b2093a2 100644 --- a/cmd/hooks/user-prompt/main.go +++ b/cmd/hooks/user-prompt/main.go @@ -94,8 +94,18 @@ func handleUserPrompt(ctx *hooks.HookContext, input *Input) (string, error) { return "", nil } - sessionID := int64(result["sessionDbId"].(float64)) - promptNumber := int(result["promptNumber"].(float64)) + // Safely extract session ID and prompt number with type checking + sessionDbIdRaw, ok := result["sessionDbId"].(float64) + if !ok { + return "", fmt.Errorf("invalid or missing sessionDbId in response") + } + sessionID := int64(sessionDbIdRaw) + + promptNumberRaw, ok := result["promptNumber"].(float64) + if !ok { + return "", fmt.Errorf("invalid or missing promptNumber in response") + } + promptNumber := int(promptNumberRaw) fmt.Fprintf(os.Stderr, "[user-prompt] Session %d, prompt #%d\n", sessionID, promptNumber) diff --git a/internal/db/gorm/session_store.go b/internal/db/gorm/session_store.go index 0dc166c..2daf3fe 100644 --- a/internal/db/gorm/session_store.go +++ b/internal/db/gorm/session_store.go @@ -4,6 +4,7 @@ package gorm import ( "context" "database/sql" + "fmt" "time" "gorm.io/gorm" @@ -67,10 +68,12 @@ func (s *SessionStore) CreateSDKSession(ctx context.Context, claudeSessionID, pr if userPrompt != "" { updates["user_prompt"] = userPrompt } - s.db.WithContext(ctx). + if err := s.db.WithContext(ctx). Model(&SDKSession{}). Where("claude_session_id = ?", claudeSessionID). - Updates(updates) + Updates(updates).Error; err != nil { + return 0, fmt.Errorf("failed to update session: %w", err) + } } // Fetch existing session diff --git a/internal/mcp/server.go b/internal/mcp/server.go index c30fc5a..de9649f 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -164,7 +164,7 @@ func (s *Server) handleToolsList(req *Request) *Response { tools := []Tool{ { Name: "search", - Description: "Unified search across all memory types (observations, sessions, and user prompts) using vector-first semantic search (ChromaDB).", + Description: "Unified search across all memory types (observations, sessions, and user prompts) using vector-first semantic search (sqlite-vec).", InputSchema: map[string]any{ "type": "object", "properties": map[string]any{ @@ -599,7 +599,8 @@ func (s *Server) handleFindRelatedObservations(ctx context.Context, args json.Ra return "", fmt.Errorf("id is required") } - if params.MinConfidence == 0 { + // Use -1 as sentinel for "not provided" since 0.0 is a valid threshold + if params.MinConfidence < 0 { params.MinConfidence = 0.5 } diff --git a/internal/worker/handlers.go b/internal/worker/handlers.go index 40593a5..f22dea6 100644 --- a/internal/worker/handlers.go +++ b/internal/worker/handlers.go @@ -774,7 +774,7 @@ func (s *Service) handleSearchByPrompt(w http.ResponseWriter, r *http.Request) { // Try vector search first if available if s.vectorClient != nil && s.vectorClient.IsConnected() { - where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, "") + where := sqlitevec.BuildWhereFilter(sqlitevec.DocTypeObservation, project) // Search with each expanded query and merge results allVectorResults := make([]sqlitevec.QueryResult, 0) diff --git a/pkg/hooks/worker.go b/pkg/hooks/worker.go index 168b8ae..a0816be 100644 --- a/pkg/hooks/worker.go +++ b/pkg/hooks/worker.go @@ -251,8 +251,8 @@ func POST(port int, path string, body interface{}) (map[string]interface{}, erro var result map[string]interface{} if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - // Not all endpoints return JSON - return nil, nil + // Not all endpoints return JSON body - return empty map for success with no body + return map[string]interface{}{}, nil } return result, nil