mirror of
https://github.com/lukaszraczylo/claude-mnemonic.git
synced 2026-06-05 23:03:55 +00:00
mnemonic ralphised (#24)
* Make things 'betterer' across the board * fix: reorganize struct fields and config parameters for consistency - [x] Reorder Config struct fields alphabetically and by related functionality - [x] Reorganize Observation model fields with archival fields grouped together - [x] Reorder ObservationStore fields to group related members - [x] Reorder Store struct fields with health check caching grouped - [x] Reorganize HealthInfo and PoolMetrics struct field order - [x] Reorder maintenance Service struct fields logically - [x] Reorganize MCP server handler parameter structs alphabetically - [x] Reorder pattern detector candidate tracking fields - [x] Reorganize search Manager struct fields by functionality - [x] Reorder vector Client struct fields with mutex protections grouped - [x] Reorganize handler request/response struct fields - [x] Update handlers_test.go to expect wrapped response format - [x] Reorder middleware TokenAuth and rate limiter fields - [x] Reorganize Service struct fields with grouped functionality - [x] Fix RateLimiter field ordering for clarity - [x] Reorder CircuitBreaker metrics fields * fix(security): improve JSON output safety and path traversal protection - [x] Replace unsafe JSON string formatting with proper json.Marshal in export handler - [x] Remove escapeJSONString helper function in favor of standard JSON marshaling - [x] Add safeResolvePath function to validate paths and prevent directory traversal - [x] Apply path traversal validation in captureFileMtimes operations - [x] Cap result slice capacity in getRecentSearchQueries to prevent DoS via excessive allocation * fix(sdk): improve path traversal protection and allocation safety - [x] Enhance safeResolvePath with stricter validation using filepath.Rel - [x] Reject paths containing ".." after cleaning to prevent traversal - [x] Validate absolute paths are within cwd when cwd is specified - [x] Apply safeResolvePath validation to GetFileContent for consistency - [x] Add comprehensive test coverage for path traversal protection - [x] Fix allocation safety in getRecentSearchQueries by using constant capacity * feat(dashboard): add graph stats and vector metrics endpoints - [x] Add handleGraphStats endpoint for knowledge graph visualization - [x] Add handleVectorMetrics endpoint for vector database dashboard - [x] Improve update check error handling with JSON response - [x] Register new API routes for graph and vector metrics - [x] Migrate Font Awesome to npm package from CDN - [x] Fix observations API response type handling - [x] Update package version to v0.10.5-15-g385d05a * fixup! feat(dashboard): add graph stats and vector metrics endpoints * test: add comprehensive test coverage across multiple packages - [x] Add 298 tests for Python chunker functionality - [x] Add 213 tests for chunking types and constants - [x] Add 398 tests for TypeScript/JavaScript chunker - [x] Add 954 tests for MCP server handlers and validation - [x] Add 563 tests for pattern detector and analysis - [x] Add 1149 tests for vector client cache and operations - [x] Add 663 tests for SDK processor, circuit breaker, and deduplication - [x] Add 731 tests for session manager lifecycle and concurrency - [x] Add 331 tests for similarity clustering and term extraction * fix(pattern): add nil check and fmt import for GetPatternInsight - [x] Add `fmt` import for error formatting - [x] Add nil check for pattern before using it - [x] Remove duplicate comment line
This commit is contained in:
@@ -593,3 +593,118 @@ func (s *Service) handleGetObservationByID(w http.ResponseWriter, r *http.Reques
|
||||
|
||||
writeJSON(w, obs)
|
||||
}
|
||||
|
||||
// handleGraphStats returns graph statistics for the dashboard.
|
||||
// Uses relation data to compute knowledge graph metrics.
|
||||
func (s *Service) handleGraphStats(w http.ResponseWriter, r *http.Request) {
|
||||
// Get relation count (edges) - this represents the knowledge graph
|
||||
edgeCount, err := s.relationStore.GetTotalRelationCount(r.Context())
|
||||
if err != nil {
|
||||
edgeCount = 0
|
||||
}
|
||||
|
||||
// Count by relation type
|
||||
edgeTypes := make(map[string]int)
|
||||
for _, t := range models.AllRelationTypes {
|
||||
relations, err := s.relationStore.GetRelationsByType(r.Context(), t, 10000)
|
||||
if err == nil {
|
||||
edgeTypes[string(t)] = len(relations)
|
||||
}
|
||||
}
|
||||
|
||||
// Get unique observation IDs involved in relations (approximate node count)
|
||||
// For now, use edge count as a proxy - each edge has 2 nodes
|
||||
nodeCount := 0
|
||||
if edgeCount > 0 {
|
||||
// Rough estimate: unique nodes ≈ edges * 1.5 (since nodes can have multiple edges)
|
||||
nodeCount = int(float64(edgeCount) * 1.5)
|
||||
}
|
||||
|
||||
// Calculate average degree
|
||||
var avgDegree float64
|
||||
if nodeCount > 0 {
|
||||
avgDegree = float64(edgeCount*2) / float64(nodeCount)
|
||||
}
|
||||
|
||||
// Graph is enabled if we have any edges (relations)
|
||||
enabled := edgeCount > 0
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"enabled": enabled,
|
||||
"nodeCount": nodeCount,
|
||||
"edgeCount": edgeCount,
|
||||
"avgDegree": avgDegree,
|
||||
"maxDegree": 0,
|
||||
"minDegree": 0,
|
||||
"medianDegree": 0.0,
|
||||
"edgeTypes": edgeTypes,
|
||||
"config": map[string]any{
|
||||
"maxHops": 2,
|
||||
"branchFactor": 10,
|
||||
"edgeWeight": 0.3,
|
||||
"rebuildIntervalMin": 30,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// handleVectorMetrics returns vector database metrics for the dashboard.
|
||||
// Returns enabled: false if vector features are not available.
|
||||
func (s *Service) handleVectorMetrics(w http.ResponseWriter, r *http.Request) {
|
||||
if s.vectorClient == nil {
|
||||
writeJSON(w, map[string]any{
|
||||
"enabled": false,
|
||||
"message": "Vector database not initialized",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get cache stats from vector client
|
||||
cacheSize, cacheMax := s.vectorClient.CacheStats()
|
||||
cacheStats := s.vectorClient.GetCacheStats()
|
||||
count, _ := s.vectorClient.Count(r.Context())
|
||||
|
||||
uptime := time.Since(s.startTime).Round(time.Second).String()
|
||||
|
||||
// Calculate total queries from cache hits/misses
|
||||
totalQueries := cacheStats.EmbeddingHits + cacheStats.EmbeddingMisses + cacheStats.ResultHits + cacheStats.ResultMisses
|
||||
totalHits := cacheStats.EmbeddingHits + cacheStats.ResultHits
|
||||
totalMisses := cacheStats.EmbeddingMisses + cacheStats.ResultMisses
|
||||
|
||||
writeJSON(w, map[string]any{
|
||||
"enabled": true,
|
||||
"queries": map[string]any{
|
||||
"total": totalQueries,
|
||||
"hubOnly": 0,
|
||||
"hybrid": 0,
|
||||
"onDemand": 0,
|
||||
"graph": 0,
|
||||
},
|
||||
"latency": map[string]any{
|
||||
"avg": "0ms",
|
||||
"p50": "0ms",
|
||||
"p95": "0ms",
|
||||
"p99": "0ms",
|
||||
"avgHub": "0ms",
|
||||
"avgRecompute": "0ms",
|
||||
},
|
||||
"storage": map[string]any{
|
||||
"totalDocuments": count,
|
||||
"hubDocuments": 0,
|
||||
"storedEmbeddings": count,
|
||||
"savingsPercent": 0.0,
|
||||
"recomputedTotal": 0,
|
||||
},
|
||||
"cache": map[string]any{
|
||||
"hits": totalHits,
|
||||
"misses": totalMisses,
|
||||
"hitRate": cacheStats.HitRate(),
|
||||
"size": cacheSize,
|
||||
"maxSize": cacheMax,
|
||||
},
|
||||
"graph": map[string]any{
|
||||
"traversals": 0,
|
||||
"avgDepth": 0.0,
|
||||
},
|
||||
"uptime": uptime,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ package worker
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
@@ -13,7 +14,14 @@ import (
|
||||
func (s *Service) handleUpdateCheck(w http.ResponseWriter, r *http.Request) {
|
||||
info, err := s.updater.CheckForUpdate(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
// Return a proper JSON response for errors instead of 500
|
||||
// This allows the frontend to handle it gracefully
|
||||
writeJSON(w, map[string]any{
|
||||
"available": false,
|
||||
"current_version": s.version,
|
||||
"error": err.Error(),
|
||||
"rate_limited": strings.Contains(err.Error(), "403"),
|
||||
})
|
||||
return
|
||||
}
|
||||
writeJSON(w, info)
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package sdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/claude-mnemonic/pkg/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -1178,3 +1181,663 @@ func TestSafeResolvePath(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TESTS FOR CircuitBreaker
|
||||
// =============================================================================
|
||||
|
||||
func TestNewCircuitBreaker(t *testing.T) {
|
||||
cb := NewCircuitBreaker(5, 60)
|
||||
|
||||
assert.NotNil(t, cb)
|
||||
assert.Equal(t, int64(5), cb.threshold)
|
||||
assert.Equal(t, int64(60), cb.resetTimeout)
|
||||
assert.Equal(t, "closed", cb.State())
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_Allow_Closed(t *testing.T) {
|
||||
cb := NewCircuitBreaker(5, 60)
|
||||
|
||||
// Closed state should allow requests
|
||||
assert.True(t, cb.Allow())
|
||||
assert.True(t, cb.Allow())
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_Allow_Open(t *testing.T) {
|
||||
cb := NewCircuitBreaker(2, 60) // Low threshold for testing
|
||||
|
||||
// Record enough failures to open the circuit
|
||||
cb.RecordFailure()
|
||||
cb.RecordFailure()
|
||||
|
||||
// Open state should block requests
|
||||
assert.False(t, cb.Allow())
|
||||
assert.Equal(t, "open", cb.State())
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_RecordSuccess(t *testing.T) {
|
||||
cb := NewCircuitBreaker(2, 60)
|
||||
|
||||
// Record a failure
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, int64(1), cb.Metrics().Failures)
|
||||
|
||||
// Record success resets failures
|
||||
cb.RecordSuccess()
|
||||
assert.Equal(t, int64(0), cb.Metrics().Failures)
|
||||
assert.Equal(t, "closed", cb.State())
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_RecordFailure_OpensCircuit(t *testing.T) {
|
||||
cb := NewCircuitBreaker(3, 60)
|
||||
|
||||
// Record failures below threshold
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, "closed", cb.State())
|
||||
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, "closed", cb.State())
|
||||
|
||||
// Third failure should open circuit
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, "open", cb.State())
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_State(t *testing.T) {
|
||||
cb := NewCircuitBreaker(1, 60)
|
||||
|
||||
// Initially closed
|
||||
assert.Equal(t, "closed", cb.State())
|
||||
|
||||
// After failure, open
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, "open", cb.State())
|
||||
|
||||
// After success, closed
|
||||
cb.RecordSuccess()
|
||||
assert.Equal(t, "closed", cb.State())
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_Metrics(t *testing.T) {
|
||||
cb := NewCircuitBreaker(5, 120)
|
||||
|
||||
metrics := cb.Metrics()
|
||||
assert.Equal(t, "closed", metrics.State)
|
||||
assert.Equal(t, int64(0), metrics.Failures)
|
||||
assert.Equal(t, int64(5), metrics.Threshold)
|
||||
assert.Equal(t, int64(120), metrics.ResetTimeoutSecs)
|
||||
assert.Equal(t, int64(0), metrics.LastFailureUnix)
|
||||
|
||||
// After failure
|
||||
cb.RecordFailure()
|
||||
metrics = cb.Metrics()
|
||||
assert.Equal(t, int64(1), metrics.Failures)
|
||||
assert.Greater(t, metrics.LastFailureUnix, int64(0))
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_Metrics_OpenWithReset(t *testing.T) {
|
||||
cb := NewCircuitBreaker(1, 60)
|
||||
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, "open", cb.State())
|
||||
|
||||
metrics := cb.Metrics()
|
||||
assert.Equal(t, "open", metrics.State)
|
||||
assert.Greater(t, metrics.SecondsUntilReset, int64(0))
|
||||
assert.LessOrEqual(t, metrics.SecondsUntilReset, int64(60))
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TESTS FOR RequestDeduplicator
|
||||
// =============================================================================
|
||||
|
||||
func TestNewRequestDeduplicator(t *testing.T) {
|
||||
d := NewRequestDeduplicator(300, 1000)
|
||||
|
||||
assert.NotNil(t, d)
|
||||
assert.NotNil(t, d.seen)
|
||||
assert.Equal(t, int64(300), d.ttlSecs)
|
||||
assert.Equal(t, 1000, d.maxSize)
|
||||
}
|
||||
|
||||
func TestRequestDeduplicator_IsDuplicate_NotSeen(t *testing.T) {
|
||||
d := NewRequestDeduplicator(300, 1000)
|
||||
|
||||
// New hash is not a duplicate
|
||||
assert.False(t, d.IsDuplicate("newhash"))
|
||||
}
|
||||
|
||||
func TestRequestDeduplicator_IsDuplicate_AfterRecord(t *testing.T) {
|
||||
d := NewRequestDeduplicator(300, 1000)
|
||||
|
||||
hash := "testhash"
|
||||
|
||||
// Record the hash
|
||||
d.Record(hash)
|
||||
|
||||
// Now it should be a duplicate
|
||||
assert.True(t, d.IsDuplicate(hash))
|
||||
}
|
||||
|
||||
func TestRequestDeduplicator_Record(t *testing.T) {
|
||||
d := NewRequestDeduplicator(300, 1000)
|
||||
|
||||
hash := "recordtest"
|
||||
d.Record(hash)
|
||||
|
||||
// Check it was recorded
|
||||
d.mu.RLock()
|
||||
_, exists := d.seen[hash]
|
||||
d.mu.RUnlock()
|
||||
|
||||
assert.True(t, exists)
|
||||
}
|
||||
|
||||
func TestRequestDeduplicator_Record_Eviction(t *testing.T) {
|
||||
// Small maxSize for testing eviction
|
||||
d := NewRequestDeduplicator(0, 2) // TTL of 0 means everything is "old"
|
||||
|
||||
// Record until capacity
|
||||
d.Record("hash1")
|
||||
d.Record("hash2")
|
||||
|
||||
// Recording a third should trigger eviction (since TTL is 0)
|
||||
d.Record("hash3")
|
||||
|
||||
// Should have cleaned up old entries
|
||||
d.mu.RLock()
|
||||
size := len(d.seen)
|
||||
d.mu.RUnlock()
|
||||
|
||||
// Size should be limited (eviction occurred)
|
||||
assert.LessOrEqual(t, size, 3)
|
||||
}
|
||||
|
||||
func TestHashRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
toolName string
|
||||
input string
|
||||
output string
|
||||
compareWith []string
|
||||
wantLen int
|
||||
wantSame bool
|
||||
}{
|
||||
{
|
||||
name: "basic hash",
|
||||
toolName: "Read",
|
||||
input: "file.txt",
|
||||
output: "content",
|
||||
wantLen: 16,
|
||||
},
|
||||
{
|
||||
name: "consistent hashing",
|
||||
toolName: "Edit",
|
||||
input: "same input",
|
||||
output: "same output",
|
||||
wantLen: 16,
|
||||
},
|
||||
{
|
||||
name: "long output truncation",
|
||||
toolName: "Bash",
|
||||
input: "command",
|
||||
output: string(make([]byte, 5000)), // Very long output
|
||||
wantLen: 16,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hash := hashRequest(tt.toolName, tt.input, tt.output)
|
||||
assert.Len(t, hash, tt.wantLen)
|
||||
|
||||
// Same inputs should produce same hash
|
||||
hash2 := hashRequest(tt.toolName, tt.input, tt.output)
|
||||
assert.Equal(t, hash, hash2)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashRequest_DifferentInputs(t *testing.T) {
|
||||
// Different inputs should produce different hashes
|
||||
hash1 := hashRequest("Read", "file1.txt", "content1")
|
||||
hash2 := hashRequest("Read", "file2.txt", "content2")
|
||||
|
||||
assert.NotEqual(t, hash1, hash2)
|
||||
}
|
||||
|
||||
func TestHashRequest_OutputTruncation(t *testing.T) {
|
||||
// Hash should be the same for outputs that differ only after 1000 chars
|
||||
longOutput1 := string(make([]byte, 1500))
|
||||
longOutput2 := longOutput1[:1000] + "different suffix here"
|
||||
|
||||
hash1 := hashRequest("Read", "input", longOutput1)
|
||||
hash2 := hashRequest("Read", "input", longOutput2)
|
||||
|
||||
// Since we only hash first 1000 chars, these should be the same
|
||||
assert.Equal(t, hash1, hash2)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TESTS FOR Processor methods
|
||||
// =============================================================================
|
||||
|
||||
func TestProcessor_CircuitBreakerState(t *testing.T) {
|
||||
p := &Processor{
|
||||
circuitBreaker: NewCircuitBreaker(2, 60),
|
||||
}
|
||||
|
||||
// Initially closed
|
||||
assert.Equal(t, "closed", p.CircuitBreakerState())
|
||||
|
||||
// After enough failures, open
|
||||
p.circuitBreaker.RecordFailure()
|
||||
p.circuitBreaker.RecordFailure()
|
||||
assert.Equal(t, "open", p.CircuitBreakerState())
|
||||
|
||||
// After success, closed
|
||||
p.circuitBreaker.RecordSuccess()
|
||||
assert.Equal(t, "closed", p.CircuitBreakerState())
|
||||
}
|
||||
|
||||
func TestProcessor_CircuitBreakerMetrics(t *testing.T) {
|
||||
p := &Processor{
|
||||
circuitBreaker: NewCircuitBreaker(5, 120),
|
||||
}
|
||||
|
||||
metrics := p.CircuitBreakerMetrics()
|
||||
assert.Equal(t, "closed", metrics.State)
|
||||
assert.Equal(t, int64(0), metrics.Failures)
|
||||
assert.Equal(t, int64(5), metrics.Threshold)
|
||||
assert.Equal(t, int64(120), metrics.ResetTimeoutSecs)
|
||||
|
||||
// Record a failure and check metrics update
|
||||
p.circuitBreaker.RecordFailure()
|
||||
metrics = p.CircuitBreakerMetrics()
|
||||
assert.Equal(t, int64(1), metrics.Failures)
|
||||
assert.Greater(t, metrics.LastFailureUnix, int64(0))
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TESTS FOR Vector Sync Workers
|
||||
// =============================================================================
|
||||
|
||||
func TestProcessor_StartAndStopVectorSyncWorkers(t *testing.T) {
|
||||
var syncedObservations []*models.Observation
|
||||
var mu sync.Mutex
|
||||
|
||||
p := &Processor{
|
||||
vectorSyncChan: make(chan *models.Observation, MaxVectorSyncWorkers*2),
|
||||
vectorSyncDone: make(chan struct{}),
|
||||
syncObservationFunc: func(obs *models.Observation) {
|
||||
mu.Lock()
|
||||
syncedObservations = append(syncedObservations, obs)
|
||||
mu.Unlock()
|
||||
},
|
||||
}
|
||||
|
||||
// Start workers
|
||||
p.StartVectorSyncWorkers()
|
||||
|
||||
// Send some observations
|
||||
obs1 := &models.Observation{SDKSessionID: "test1"}
|
||||
obs2 := &models.Observation{SDKSessionID: "test2"}
|
||||
p.vectorSyncChan <- obs1
|
||||
p.vectorSyncChan <- obs2
|
||||
|
||||
// Give workers time to process
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Stop workers
|
||||
p.StopVectorSyncWorkers()
|
||||
|
||||
// Verify observations were synced
|
||||
mu.Lock()
|
||||
assert.Len(t, syncedObservations, 2)
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
func TestProcessor_VectorSyncWorker_DrainOnShutdown(t *testing.T) {
|
||||
var syncedCount int
|
||||
var mu sync.Mutex
|
||||
|
||||
p := &Processor{
|
||||
vectorSyncChan: make(chan *models.Observation, 10),
|
||||
vectorSyncDone: make(chan struct{}),
|
||||
syncObservationFunc: func(obs *models.Observation) {
|
||||
mu.Lock()
|
||||
syncedCount++
|
||||
mu.Unlock()
|
||||
},
|
||||
}
|
||||
|
||||
// Queue observations before starting workers
|
||||
for i := 0; i < 5; i++ {
|
||||
p.vectorSyncChan <- &models.Observation{SDKSessionID: "pre-queued"}
|
||||
}
|
||||
|
||||
// Start workers
|
||||
p.StartVectorSyncWorkers()
|
||||
|
||||
// Stop immediately - workers should drain the queue
|
||||
p.StopVectorSyncWorkers()
|
||||
|
||||
// All pre-queued items should have been processed
|
||||
mu.Lock()
|
||||
assert.Equal(t, 5, syncedCount)
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
func TestProcessor_VectorSyncWorker_NilSyncFunc(t *testing.T) {
|
||||
p := &Processor{
|
||||
vectorSyncChan: make(chan *models.Observation, 10),
|
||||
vectorSyncDone: make(chan struct{}),
|
||||
syncObservationFunc: nil, // No sync function set
|
||||
}
|
||||
|
||||
// Start workers
|
||||
p.StartVectorSyncWorkers()
|
||||
|
||||
// Send observation - should not panic even with nil sync func
|
||||
p.vectorSyncChan <- &models.Observation{SDKSessionID: "test"}
|
||||
|
||||
// Give it time to process
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Stop workers - should not panic
|
||||
p.StopVectorSyncWorkers()
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TESTS FOR CircuitBreaker Additional Behaviors
|
||||
// =============================================================================
|
||||
|
||||
func TestCircuitBreaker_Allow_OpenBlocksRequests(t *testing.T) {
|
||||
cb := NewCircuitBreaker(1, 60)
|
||||
|
||||
// Open the circuit
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, "open", cb.State())
|
||||
|
||||
// All requests should be blocked
|
||||
assert.False(t, cb.Allow())
|
||||
assert.False(t, cb.Allow())
|
||||
assert.False(t, cb.Allow())
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_MultipleFailures(t *testing.T) {
|
||||
cb := NewCircuitBreaker(3, 60) // Higher threshold
|
||||
|
||||
// Record failures below threshold
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, "closed", cb.State())
|
||||
assert.Equal(t, int64(1), cb.Metrics().Failures)
|
||||
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, "closed", cb.State())
|
||||
assert.Equal(t, int64(2), cb.Metrics().Failures)
|
||||
|
||||
// Third failure opens circuit
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, "open", cb.State())
|
||||
assert.Equal(t, int64(3), cb.Metrics().Failures)
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_SuccessResetsFailures(t *testing.T) {
|
||||
cb := NewCircuitBreaker(5, 60)
|
||||
|
||||
// Record some failures
|
||||
cb.RecordFailure()
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, int64(2), cb.Metrics().Failures)
|
||||
|
||||
// Success resets failures
|
||||
cb.RecordSuccess()
|
||||
assert.Equal(t, int64(0), cb.Metrics().Failures)
|
||||
assert.Equal(t, "closed", cb.State())
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_Metrics_Comprehensive(t *testing.T) {
|
||||
cb := NewCircuitBreaker(5, 120)
|
||||
|
||||
// Initial state
|
||||
metrics := cb.Metrics()
|
||||
assert.Equal(t, "closed", metrics.State)
|
||||
assert.Equal(t, int64(0), metrics.Failures)
|
||||
assert.Equal(t, int64(5), metrics.Threshold)
|
||||
assert.Equal(t, int64(120), metrics.ResetTimeoutSecs)
|
||||
assert.Equal(t, int64(0), metrics.LastFailureUnix)
|
||||
assert.Equal(t, int64(0), metrics.SecondsUntilReset)
|
||||
|
||||
// After failures that open circuit
|
||||
for i := 0; i < 5; i++ {
|
||||
cb.RecordFailure()
|
||||
}
|
||||
metrics = cb.Metrics()
|
||||
assert.Equal(t, "open", metrics.State)
|
||||
assert.Equal(t, int64(5), metrics.Failures)
|
||||
assert.Greater(t, metrics.LastFailureUnix, int64(0))
|
||||
assert.Greater(t, metrics.SecondsUntilReset, int64(0))
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TESTS FOR MaxVectorSyncWorkers constant
|
||||
// =============================================================================
|
||||
|
||||
func TestMaxVectorSyncWorkers(t *testing.T) {
|
||||
assert.Equal(t, 8, MaxVectorSyncWorkers)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// ADDITIONAL EDGE CASE TESTS
|
||||
// =============================================================================
|
||||
|
||||
func TestRequestDeduplicator_IsDuplicate_ExpiredEntry(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping time-dependent test in short mode")
|
||||
}
|
||||
// Use a 1-second TTL with enough margin
|
||||
d := NewRequestDeduplicator(1, 100)
|
||||
|
||||
hash := "expiretest"
|
||||
d.Record(hash)
|
||||
|
||||
// Initially duplicate
|
||||
assert.True(t, d.IsDuplicate(hash))
|
||||
|
||||
// Wait for TTL to expire (2.5 seconds to ensure crossing second boundaries)
|
||||
time.Sleep(2500 * time.Millisecond)
|
||||
|
||||
// Should no longer be considered duplicate
|
||||
assert.False(t, d.IsDuplicate(hash))
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TESTS FOR ProcessObservation Early Returns
|
||||
// =============================================================================
|
||||
|
||||
func TestProcessObservation_SkipTool(t *testing.T) {
|
||||
p := &Processor{
|
||||
circuitBreaker: NewCircuitBreaker(5, 60),
|
||||
deduplicator: NewRequestDeduplicator(300, 1000),
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// TodoWrite should be skipped
|
||||
err := p.ProcessObservation(ctx, "session-1", "project-1", "TodoWrite",
|
||||
map[string]string{"content": "test"}, "success", 1, "/test/cwd")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Glob should be skipped
|
||||
err = p.ProcessObservation(ctx, "session-1", "project-1", "Glob",
|
||||
map[string]string{"pattern": "*.go"}, []string{"main.go", "test.go"}, 1, "/test/cwd")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// AskUserQuestion should be skipped
|
||||
err = p.ProcessObservation(ctx, "session-1", "project-1", "AskUserQuestion",
|
||||
"question", "answer", 1, "/test/cwd")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestProcessObservation_SkipTrivial(t *testing.T) {
|
||||
p := &Processor{
|
||||
circuitBreaker: NewCircuitBreaker(5, 60),
|
||||
deduplicator: NewRequestDeduplicator(300, 1000),
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Short output should be skipped
|
||||
err := p.ProcessObservation(ctx, "session-1", "project-1", "Read",
|
||||
map[string]string{"file_path": "/test.go"}, "short", 1, "/test/cwd")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// "No matches found" should be skipped
|
||||
err = p.ProcessObservation(ctx, "session-1", "project-1", "Grep",
|
||||
map[string]string{"pattern": "test"}, "No matches found in the repository", 1, "/test/cwd")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestProcessObservation_SkipDuplicate(t *testing.T) {
|
||||
p := &Processor{
|
||||
circuitBreaker: NewCircuitBreaker(5, 60),
|
||||
deduplicator: NewRequestDeduplicator(300, 1000),
|
||||
sem: make(chan struct{}, 4),
|
||||
claudePath: "/nonexistent/path", // Will fail at CLI call stage
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Valid input that would be processed
|
||||
input := map[string]string{"file_path": "/project/main.go"}
|
||||
output := "package main\n\nimport \"fmt\"\n\nfunc main() {\n\tfmt.Println(\"Hello World\")\n}"
|
||||
|
||||
// First call should try to process (will fail because claudePath doesn't exist)
|
||||
err := p.ProcessObservation(ctx, "session-1", "project-1", "Read", input, output, 1, "/test/cwd")
|
||||
// Expect error because claudePath doesn't exist
|
||||
assert.Error(t, err)
|
||||
|
||||
// Second call with same input should be skipped as duplicate
|
||||
err = p.ProcessObservation(ctx, "session-1", "project-1", "Read", input, output, 1, "/test/cwd")
|
||||
assert.NoError(t, err) // No error because it was skipped as duplicate
|
||||
}
|
||||
|
||||
func TestProcessObservation_CircuitBreakerOpen(t *testing.T) {
|
||||
cb := NewCircuitBreaker(1, 60) // Threshold of 1
|
||||
cb.RecordFailure() // Open the circuit breaker
|
||||
|
||||
p := &Processor{
|
||||
circuitBreaker: cb,
|
||||
deduplicator: NewRequestDeduplicator(300, 1000),
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Valid input that would be processed
|
||||
input := map[string]string{"file_path": "/project/main.go"}
|
||||
output := "package main\n\nimport \"fmt\"\n\nfunc main() {\n\tfmt.Println(\"Hello World\")\n}"
|
||||
|
||||
err := p.ProcessObservation(ctx, "session-1", "project-1", "Read", input, output, 1, "/test/cwd")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "circuit breaker open")
|
||||
}
|
||||
|
||||
func TestProcessObservation_ContextCancel(t *testing.T) {
|
||||
p := &Processor{
|
||||
circuitBreaker: NewCircuitBreaker(5, 60),
|
||||
deduplicator: NewRequestDeduplicator(300, 1000),
|
||||
sem: make(chan struct{}, 1), // Small semaphore
|
||||
claudePath: "/fake/claude",
|
||||
}
|
||||
|
||||
// Fill the semaphore
|
||||
p.sem <- struct{}{}
|
||||
|
||||
// Create a cancelled context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
// Valid input that would be processed
|
||||
input := map[string]string{"file_path": "/project/main.go"}
|
||||
output := "package main\n\nimport \"fmt\"\n\nfunc main() {\n\tfmt.Println(\"Hello World\")\n}"
|
||||
|
||||
err := p.ProcessObservation(ctx, "session-1", "project-1", "Read", input, output, 1, "/test/cwd")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, context.Canceled)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TESTS FOR ProcessSummary Early Returns
|
||||
// =============================================================================
|
||||
|
||||
func TestProcessSummary_SkipEmptyRequest(t *testing.T) {
|
||||
p := &Processor{
|
||||
circuitBreaker: NewCircuitBreaker(5, 60),
|
||||
deduplicator: NewRequestDeduplicator(300, 1000),
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Empty request should be skipped (sessionDBID, sdkSessionID, project, userPrompt, lastUserMsg, lastAssistantMsg)
|
||||
err := p.ProcessSummary(ctx, 1, "session-1", "project-1", "", "", "")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestProcessSummary_CircuitBreakerOpen(t *testing.T) {
|
||||
cb := NewCircuitBreaker(1, 60)
|
||||
cb.RecordFailure() // Open the circuit breaker
|
||||
|
||||
p := &Processor{
|
||||
circuitBreaker: cb,
|
||||
deduplicator: NewRequestDeduplicator(300, 1000),
|
||||
sem: make(chan struct{}, 4),
|
||||
claudePath: "/nonexistent/path",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Meaningful assistant message (> 200 chars, contains code discussion)
|
||||
assistantMsg := `I've updated the handler.go file to fix the authentication bug.
|
||||
The function validateToken() was not checking token expiry correctly.
|
||||
I've added a check for the exp claim and implemented proper error handling.
|
||||
The changes have been tested and the build passes successfully.
|
||||
Here's the implementation details and code review.`
|
||||
|
||||
// Valid request but circuit breaker is open
|
||||
err := p.ProcessSummary(ctx, 1, "session-1", "project-1",
|
||||
"Implement authentication", "User message", assistantMsg)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "claude CLI failed")
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TESTS FOR callClaudeCLI Error Paths
|
||||
// =============================================================================
|
||||
|
||||
func TestCallClaudeCLI_PromptTooLarge(t *testing.T) {
|
||||
p := &Processor{
|
||||
claudePath: "/fake/claude",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a prompt that exceeds MaxPromptSize
|
||||
largePrompt := string(make([]byte, MaxPromptSize+1))
|
||||
|
||||
_, err := p.callClaudeCLI(ctx, largePrompt)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "prompt exceeds maximum size")
|
||||
}
|
||||
|
||||
func TestCallClaudeCLI_BinaryNotFound(t *testing.T) {
|
||||
p := &Processor{
|
||||
claudePath: "/nonexistent/path/to/claude",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := p.callClaudeCLI(ctx, "test prompt")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "claude CLI failed")
|
||||
}
|
||||
|
||||
@@ -1198,6 +1198,8 @@ func (s *Service) setupRoutes() {
|
||||
// Vector management endpoints
|
||||
s.router.Post("/api/vectors/rebuild", s.handleTriggerVectorRebuild)
|
||||
s.router.Get("/api/vectors/health", s.handleVectorHealth)
|
||||
s.router.Get("/api/vector/metrics", s.handleVectorMetrics)
|
||||
s.router.Get("/api/graph/stats", s.handleGraphStats)
|
||||
|
||||
// Readiness check - returns 200 only when fully initialized
|
||||
s.router.Get("/api/ready", s.handleReady)
|
||||
|
||||
@@ -693,3 +693,734 @@ func TestToolInputResponse(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TESTS FOR NewManager AND CLEANUP
|
||||
// =============================================================================
|
||||
|
||||
// TestNewManager tests the NewManager function.
|
||||
func TestNewManager(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Test with nil session store (valid for testing)
|
||||
manager := NewManager(nil)
|
||||
|
||||
assert.NotNil(t, manager)
|
||||
assert.NotNil(t, manager.sessions)
|
||||
assert.NotNil(t, manager.ProcessNotify)
|
||||
assert.NotNil(t, manager.ctx)
|
||||
assert.NotNil(t, manager.cancel)
|
||||
assert.Equal(t, 0, manager.GetActiveSessionCount())
|
||||
|
||||
// Clean up - cancel context to stop cleanup goroutine
|
||||
manager.cancel()
|
||||
}
|
||||
|
||||
// TestNewManager_CleanupGoroutineStops tests that cleanup goroutine stops on cancel.
|
||||
func TestNewManager_CleanupGoroutineStops(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := NewManager(nil)
|
||||
|
||||
// Give goroutine time to start
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Cancel should stop the cleanup goroutine
|
||||
manager.cancel()
|
||||
|
||||
// Context should be done
|
||||
select {
|
||||
case <-manager.ctx.Done():
|
||||
// Expected
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Context should be done after cancel")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCleanupStaleSessions_NoSessions tests cleanup with no sessions.
|
||||
func TestCleanupStaleSessions_NoSessions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := &Manager{
|
||||
sessions: make(map[int64]*ActiveSession),
|
||||
ProcessNotify: make(chan struct{}, 1),
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
manager.ctx = ctx
|
||||
manager.cancel = cancel
|
||||
|
||||
// Should not panic with empty sessions
|
||||
manager.cleanupStaleSessions()
|
||||
assert.Equal(t, 0, manager.GetActiveSessionCount())
|
||||
}
|
||||
|
||||
// TestCleanupStaleSessions_FreshSession tests that fresh sessions are not cleaned.
|
||||
func TestCleanupStaleSessions_FreshSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := &Manager{
|
||||
sessions: make(map[int64]*ActiveSession),
|
||||
ProcessNotify: make(chan struct{}, 1),
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
manager.ctx = ctx
|
||||
manager.cancel = cancel
|
||||
|
||||
// Add a fresh session
|
||||
sessionCtx, sessionCancel := context.WithCancel(context.Background())
|
||||
manager.sessions[1] = &ActiveSession{
|
||||
SessionDBID: 1,
|
||||
StartTime: time.Now(), // Fresh
|
||||
pendingMessages: []PendingMessage{},
|
||||
ctx: sessionCtx,
|
||||
cancel: sessionCancel,
|
||||
}
|
||||
|
||||
manager.cleanupStaleSessions()
|
||||
|
||||
// Session should still exist (not stale)
|
||||
assert.Equal(t, 1, manager.GetActiveSessionCount())
|
||||
sessionCancel()
|
||||
}
|
||||
|
||||
// TestCleanupStaleSessions_StaleSession tests that stale sessions are cleaned.
|
||||
func TestCleanupStaleSessions_StaleSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := &Manager{
|
||||
sessions: make(map[int64]*ActiveSession),
|
||||
ProcessNotify: make(chan struct{}, 1),
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
manager.ctx = ctx
|
||||
manager.cancel = cancel
|
||||
|
||||
// Add a stale session
|
||||
sessionCtx, sessionCancel := context.WithCancel(context.Background())
|
||||
manager.sessions[1] = &ActiveSession{
|
||||
SessionDBID: 1,
|
||||
StartTime: time.Now().Add(-SessionTimeout - time.Minute), // Stale
|
||||
pendingMessages: []PendingMessage{},
|
||||
ctx: sessionCtx,
|
||||
cancel: sessionCancel,
|
||||
}
|
||||
|
||||
manager.cleanupStaleSessions()
|
||||
|
||||
// Session should be deleted
|
||||
assert.Equal(t, 0, manager.GetActiveSessionCount())
|
||||
}
|
||||
|
||||
// TestCleanupStaleSessions_StaleWithPending tests stale sessions with pending messages are not cleaned.
|
||||
func TestCleanupStaleSessions_StaleWithPending(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := &Manager{
|
||||
sessions: make(map[int64]*ActiveSession),
|
||||
ProcessNotify: make(chan struct{}, 1),
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
manager.ctx = ctx
|
||||
manager.cancel = cancel
|
||||
|
||||
// Add a stale session with pending messages
|
||||
sessionCtx, sessionCancel := context.WithCancel(context.Background())
|
||||
defer sessionCancel()
|
||||
manager.sessions[1] = &ActiveSession{
|
||||
SessionDBID: 1,
|
||||
StartTime: time.Now().Add(-SessionTimeout - time.Minute), // Stale
|
||||
pendingMessages: []PendingMessage{{Type: MessageTypeObservation}},
|
||||
ctx: sessionCtx,
|
||||
cancel: sessionCancel,
|
||||
}
|
||||
|
||||
manager.cleanupStaleSessions()
|
||||
|
||||
// Session should NOT be deleted (has pending messages)
|
||||
assert.Equal(t, 1, manager.GetActiveSessionCount())
|
||||
}
|
||||
|
||||
// TestCleanupStaleSessions_StaleWithActiveGenerator tests stale sessions with active generator are not cleaned.
|
||||
func TestCleanupStaleSessions_StaleWithActiveGenerator(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := &Manager{
|
||||
sessions: make(map[int64]*ActiveSession),
|
||||
ProcessNotify: make(chan struct{}, 1),
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
manager.ctx = ctx
|
||||
manager.cancel = cancel
|
||||
|
||||
// Add a stale session with active generator
|
||||
sessionCtx, sessionCancel := context.WithCancel(context.Background())
|
||||
defer sessionCancel()
|
||||
session := &ActiveSession{
|
||||
SessionDBID: 1,
|
||||
StartTime: time.Now().Add(-SessionTimeout - time.Minute), // Stale
|
||||
pendingMessages: []PendingMessage{},
|
||||
ctx: sessionCtx,
|
||||
cancel: sessionCancel,
|
||||
}
|
||||
session.generatorActive.Store(true)
|
||||
manager.sessions[1] = session
|
||||
|
||||
manager.cleanupStaleSessions()
|
||||
|
||||
// Session should NOT be deleted (generator is active)
|
||||
assert.Equal(t, 1, manager.GetActiveSessionCount())
|
||||
}
|
||||
|
||||
// TestCleanupStaleSessions_MixedSessions tests cleanup with mixed fresh and stale sessions.
|
||||
func TestCleanupStaleSessions_MixedSessions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := &Manager{
|
||||
sessions: make(map[int64]*ActiveSession),
|
||||
ProcessNotify: make(chan struct{}, 1),
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
manager.ctx = ctx
|
||||
manager.cancel = cancel
|
||||
|
||||
// Fresh session
|
||||
ctx1, cancel1 := context.WithCancel(context.Background())
|
||||
defer cancel1()
|
||||
manager.sessions[1] = &ActiveSession{
|
||||
SessionDBID: 1,
|
||||
StartTime: time.Now(),
|
||||
pendingMessages: []PendingMessage{},
|
||||
ctx: ctx1,
|
||||
cancel: cancel1,
|
||||
}
|
||||
|
||||
// Stale session (should be deleted)
|
||||
ctx2, cancel2 := context.WithCancel(context.Background())
|
||||
manager.sessions[2] = &ActiveSession{
|
||||
SessionDBID: 2,
|
||||
StartTime: time.Now().Add(-SessionTimeout - time.Minute),
|
||||
pendingMessages: []PendingMessage{},
|
||||
ctx: ctx2,
|
||||
cancel: cancel2,
|
||||
}
|
||||
|
||||
// Stale session with pending (should NOT be deleted)
|
||||
ctx3, cancel3 := context.WithCancel(context.Background())
|
||||
defer cancel3()
|
||||
manager.sessions[3] = &ActiveSession{
|
||||
SessionDBID: 3,
|
||||
StartTime: time.Now().Add(-SessionTimeout - time.Minute),
|
||||
pendingMessages: []PendingMessage{{Type: MessageTypeObservation}},
|
||||
ctx: ctx3,
|
||||
cancel: cancel3,
|
||||
}
|
||||
|
||||
manager.cleanupStaleSessions()
|
||||
|
||||
// Should have 2 sessions left (1 fresh, 1 stale with pending)
|
||||
assert.Equal(t, 2, manager.GetActiveSessionCount())
|
||||
|
||||
// Verify which sessions remain
|
||||
manager.mu.RLock()
|
||||
_, has1 := manager.sessions[1]
|
||||
_, has2 := manager.sessions[2]
|
||||
_, has3 := manager.sessions[3]
|
||||
manager.mu.RUnlock()
|
||||
|
||||
assert.True(t, has1, "Fresh session should remain")
|
||||
assert.False(t, has2, "Stale session should be deleted")
|
||||
assert.True(t, has3, "Stale session with pending should remain")
|
||||
}
|
||||
|
||||
// TestCleanupLoop_ExitsOnCancel tests that cleanup loop exits when context is cancelled.
|
||||
func TestCleanupLoop_ExitsOnCancel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := &Manager{
|
||||
sessions: make(map[int64]*ActiveSession),
|
||||
ProcessNotify: make(chan struct{}, 1),
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
manager.ctx = ctx
|
||||
manager.cancel = cancel
|
||||
|
||||
// Start cleanup loop in goroutine
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
manager.cleanupLoop()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Cancel immediately
|
||||
cancel()
|
||||
|
||||
// Should exit quickly
|
||||
select {
|
||||
case <-done:
|
||||
// Success - loop exited
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Cleanup loop should exit when context is cancelled")
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TESTS FOR InitializeSession (without DB)
|
||||
// =============================================================================
|
||||
|
||||
// TestInitializeSession_AlreadyActive tests reusing an already active session.
|
||||
func TestInitializeSession_AlreadyActive(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := &Manager{
|
||||
sessions: make(map[int64]*ActiveSession),
|
||||
ProcessNotify: make(chan struct{}, 1),
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
manager.ctx = ctx
|
||||
manager.cancel = cancel
|
||||
|
||||
// Pre-add an active session
|
||||
existingSession := &ActiveSession{
|
||||
SessionDBID: 42,
|
||||
ClaudeSessionID: "claude-existing",
|
||||
Project: "test-project",
|
||||
UserPrompt: "original prompt",
|
||||
LastPromptNumber: 1,
|
||||
StartTime: time.Now(),
|
||||
pendingMessages: make([]PendingMessage, 0),
|
||||
}
|
||||
manager.sessions[42] = existingSession
|
||||
|
||||
// Initialize same session - should reuse
|
||||
session, err := manager.InitializeSession(context.Background(), 42, "new prompt", 5)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, session)
|
||||
assert.Same(t, existingSession, session)
|
||||
assert.Equal(t, "new prompt", session.UserPrompt)
|
||||
assert.Equal(t, 5, session.LastPromptNumber)
|
||||
}
|
||||
|
||||
// TestInitializeSession_AlreadyActive_EmptyPrompt tests reusing session with empty prompt.
|
||||
func TestInitializeSession_AlreadyActive_EmptyPrompt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := &Manager{
|
||||
sessions: make(map[int64]*ActiveSession),
|
||||
ProcessNotify: make(chan struct{}, 1),
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
manager.ctx = ctx
|
||||
manager.cancel = cancel
|
||||
|
||||
// Pre-add an active session
|
||||
existingSession := &ActiveSession{
|
||||
SessionDBID: 42,
|
||||
UserPrompt: "original prompt",
|
||||
LastPromptNumber: 1,
|
||||
}
|
||||
manager.sessions[42] = existingSession
|
||||
|
||||
// Initialize with empty prompt - should NOT update
|
||||
session, err := manager.InitializeSession(context.Background(), 42, "", 0)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, session)
|
||||
assert.Equal(t, "original prompt", session.UserPrompt) // Unchanged
|
||||
assert.Equal(t, 1, session.LastPromptNumber) // Unchanged
|
||||
}
|
||||
|
||||
// TestInitializeSession_NoStore tests initialization without session store.
|
||||
func TestInitializeSession_NoStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := &Manager{
|
||||
sessionStore: nil, // No store
|
||||
sessions: make(map[int64]*ActiveSession),
|
||||
ProcessNotify: make(chan struct{}, 1),
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
manager.ctx = ctx
|
||||
manager.cancel = cancel
|
||||
|
||||
// Should fail gracefully with nil store (panic recovery not expected)
|
||||
// This tests the guard against nil sessionStore
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
_ = r // Expected panic when calling nil store - intentionally ignored
|
||||
}
|
||||
}()
|
||||
|
||||
_, _ = manager.InitializeSession(context.Background(), 999, "prompt", 1)
|
||||
}
|
||||
|
||||
// TestInitializeSession_CallbackTriggered tests that created callback is triggered.
|
||||
func TestInitializeSession_CallbackTriggered(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := &Manager{
|
||||
sessions: make(map[int64]*ActiveSession),
|
||||
ProcessNotify: make(chan struct{}, 1),
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
manager.ctx = ctx
|
||||
manager.cancel = cancel
|
||||
|
||||
var calledWithID int64
|
||||
manager.SetOnSessionCreated(func(id int64) {
|
||||
calledWithID = id
|
||||
})
|
||||
|
||||
// Add session directly (simulating what would happen after DB fetch)
|
||||
sessionCtx, sessionCancel := context.WithCancel(context.Background())
|
||||
defer sessionCancel()
|
||||
session := &ActiveSession{
|
||||
SessionDBID: 100,
|
||||
ClaudeSessionID: "test",
|
||||
Project: "project",
|
||||
StartTime: time.Now(),
|
||||
pendingMessages: make([]PendingMessage, 0),
|
||||
notify: make(chan struct{}, 1),
|
||||
ctx: sessionCtx,
|
||||
cancel: sessionCancel,
|
||||
}
|
||||
|
||||
manager.mu.Lock()
|
||||
manager.sessions[100] = session
|
||||
onCreated := manager.onCreated
|
||||
manager.mu.Unlock()
|
||||
|
||||
// Trigger callback
|
||||
if onCreated != nil {
|
||||
onCreated(100)
|
||||
}
|
||||
|
||||
assert.Equal(t, int64(100), calledWithID)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TESTS FOR QueueObservation AND QueueSummarize (without DB)
|
||||
// =============================================================================
|
||||
|
||||
// TestQueueObservation_ToExistingSession tests queuing to an existing session.
|
||||
func TestQueueObservation_ToExistingSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := &Manager{
|
||||
sessions: make(map[int64]*ActiveSession),
|
||||
ProcessNotify: make(chan struct{}, 1),
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
manager.ctx = ctx
|
||||
manager.cancel = cancel
|
||||
|
||||
// Pre-add session
|
||||
session := &ActiveSession{
|
||||
SessionDBID: 1,
|
||||
pendingMessages: make([]PendingMessage, 0),
|
||||
notify: make(chan struct{}, 1),
|
||||
}
|
||||
manager.sessions[1] = session
|
||||
|
||||
// Queue observation
|
||||
err := manager.QueueObservation(context.Background(), 1, ObservationData{
|
||||
ToolName: "Read",
|
||||
ToolInput: map[string]string{"path": "/test"},
|
||||
ToolResponse: "content",
|
||||
PromptNumber: 1,
|
||||
CWD: "/project",
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, manager.GetTotalQueueDepth())
|
||||
|
||||
// Verify message
|
||||
messages := manager.DrainMessages(1)
|
||||
assert.Len(t, messages, 1)
|
||||
assert.Equal(t, MessageTypeObservation, messages[0].Type)
|
||||
assert.Equal(t, "Read", messages[0].Observation.ToolName)
|
||||
assert.Equal(t, "/project", messages[0].Observation.CWD)
|
||||
}
|
||||
|
||||
// TestQueueObservation_NotifiesSession tests that notification is sent to session.
|
||||
func TestQueueObservation_NotifiesSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := &Manager{
|
||||
sessions: make(map[int64]*ActiveSession),
|
||||
ProcessNotify: make(chan struct{}, 1),
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
manager.ctx = ctx
|
||||
manager.cancel = cancel
|
||||
|
||||
// Pre-add session with notify channel
|
||||
session := &ActiveSession{
|
||||
SessionDBID: 1,
|
||||
pendingMessages: make([]PendingMessage, 0),
|
||||
notify: make(chan struct{}, 1),
|
||||
}
|
||||
manager.sessions[1] = session
|
||||
|
||||
// Queue observation
|
||||
err := manager.QueueObservation(context.Background(), 1, ObservationData{ToolName: "Test"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Should receive notification on session channel
|
||||
select {
|
||||
case <-session.notify:
|
||||
// Success
|
||||
default:
|
||||
t.Error("Session should receive notification")
|
||||
}
|
||||
|
||||
// Should receive notification on process channel
|
||||
select {
|
||||
case <-manager.ProcessNotify:
|
||||
// Success
|
||||
default:
|
||||
t.Error("Manager ProcessNotify should receive notification")
|
||||
}
|
||||
}
|
||||
|
||||
// TestQueueSummarize_ToExistingSession tests queuing summarize to an existing session.
|
||||
func TestQueueSummarize_ToExistingSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := &Manager{
|
||||
sessions: make(map[int64]*ActiveSession),
|
||||
ProcessNotify: make(chan struct{}, 1),
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
manager.ctx = ctx
|
||||
manager.cancel = cancel
|
||||
|
||||
// Pre-add session
|
||||
session := &ActiveSession{
|
||||
SessionDBID: 1,
|
||||
pendingMessages: make([]PendingMessage, 0),
|
||||
notify: make(chan struct{}, 1),
|
||||
}
|
||||
manager.sessions[1] = session
|
||||
|
||||
// Queue summarize
|
||||
err := manager.QueueSummarize(context.Background(), 1, "User asked question", "Assistant answered")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, manager.GetTotalQueueDepth())
|
||||
|
||||
// Verify message
|
||||
messages := manager.DrainMessages(1)
|
||||
assert.Len(t, messages, 1)
|
||||
assert.Equal(t, MessageTypeSummarize, messages[0].Type)
|
||||
assert.Equal(t, "User asked question", messages[0].Summarize.LastUserMessage)
|
||||
assert.Equal(t, "Assistant answered", messages[0].Summarize.LastAssistantMessage)
|
||||
}
|
||||
|
||||
// TestQueueSummarize_NotifiesSession tests that notification is sent to session.
|
||||
func TestQueueSummarize_NotifiesSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := &Manager{
|
||||
sessions: make(map[int64]*ActiveSession),
|
||||
ProcessNotify: make(chan struct{}, 1),
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
manager.ctx = ctx
|
||||
manager.cancel = cancel
|
||||
|
||||
// Pre-add session with notify channel
|
||||
session := &ActiveSession{
|
||||
SessionDBID: 1,
|
||||
pendingMessages: make([]PendingMessage, 0),
|
||||
notify: make(chan struct{}, 1),
|
||||
}
|
||||
manager.sessions[1] = session
|
||||
|
||||
// Queue summarize
|
||||
err := manager.QueueSummarize(context.Background(), 1, "user", "assistant")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Should receive notification on session channel
|
||||
select {
|
||||
case <-session.notify:
|
||||
// Success
|
||||
default:
|
||||
t.Error("Session should receive notification")
|
||||
}
|
||||
|
||||
// Should receive notification on process channel
|
||||
select {
|
||||
case <-manager.ProcessNotify:
|
||||
// Success
|
||||
default:
|
||||
t.Error("Manager ProcessNotify should receive notification")
|
||||
}
|
||||
}
|
||||
|
||||
// TestQueueOperations_MultipleMessages tests queuing multiple messages.
|
||||
func TestQueueOperations_MultipleMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := &Manager{
|
||||
sessions: make(map[int64]*ActiveSession),
|
||||
ProcessNotify: make(chan struct{}, 1),
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
manager.ctx = ctx
|
||||
manager.cancel = cancel
|
||||
|
||||
// Pre-add session
|
||||
session := &ActiveSession{
|
||||
SessionDBID: 1,
|
||||
pendingMessages: make([]PendingMessage, 0),
|
||||
notify: make(chan struct{}, 1),
|
||||
}
|
||||
manager.sessions[1] = session
|
||||
|
||||
// Queue multiple messages
|
||||
for i := 0; i < 10; i++ {
|
||||
if i%2 == 0 {
|
||||
err := manager.QueueObservation(context.Background(), 1, ObservationData{
|
||||
ToolName: "Tool" + string(rune('A'+i)),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
err := manager.QueueSummarize(context.Background(), 1, "user", "assistant")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, 10, manager.GetTotalQueueDepth())
|
||||
|
||||
// Drain and verify
|
||||
messages := manager.DrainMessages(1)
|
||||
assert.Len(t, messages, 10)
|
||||
}
|
||||
|
||||
// TestQueueOperations_NonBlockingNotification tests non-blocking notification behavior.
|
||||
func TestQueueOperations_NonBlockingNotification(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := &Manager{
|
||||
sessions: make(map[int64]*ActiveSession),
|
||||
ProcessNotify: make(chan struct{}, 1),
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
manager.ctx = ctx
|
||||
manager.cancel = cancel
|
||||
|
||||
// Pre-add session with full notify channel
|
||||
session := &ActiveSession{
|
||||
SessionDBID: 1,
|
||||
pendingMessages: make([]PendingMessage, 0),
|
||||
notify: make(chan struct{}, 1),
|
||||
}
|
||||
// Fill the notify channel
|
||||
session.notify <- struct{}{}
|
||||
manager.sessions[1] = session
|
||||
|
||||
// Fill ProcessNotify channel
|
||||
manager.ProcessNotify <- struct{}{}
|
||||
|
||||
// Queue should NOT block even with full channels
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
err := manager.QueueObservation(context.Background(), 1, ObservationData{ToolName: "Test"})
|
||||
assert.NoError(t, err)
|
||||
done <- true
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Success - didn't block
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Queue operation should not block even with full notification channels")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentQueueAndCleanup tests concurrent queue operations and cleanup.
|
||||
func TestConcurrentQueueAndCleanup(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := &Manager{
|
||||
sessions: make(map[int64]*ActiveSession),
|
||||
ProcessNotify: make(chan struct{}, 1),
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
manager.ctx = ctx
|
||||
manager.cancel = cancel
|
||||
|
||||
// Pre-add multiple sessions
|
||||
for i := int64(1); i <= 5; i++ {
|
||||
sessionCtx, sessionCancel := context.WithCancel(context.Background())
|
||||
manager.sessions[i] = &ActiveSession{
|
||||
SessionDBID: i,
|
||||
StartTime: time.Now(),
|
||||
pendingMessages: make([]PendingMessage, 0),
|
||||
notify: make(chan struct{}, 1),
|
||||
ctx: sessionCtx,
|
||||
cancel: sessionCancel,
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Concurrent queue operations
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
sessionID := int64((idx % 5) + 1)
|
||||
if idx%2 == 0 {
|
||||
_ = manager.QueueObservation(context.Background(), sessionID, ObservationData{ToolName: "Test"})
|
||||
} else {
|
||||
_ = manager.QueueSummarize(context.Background(), sessionID, "user", "assistant")
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Concurrent cleanup
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
manager.cleanupStaleSessions()
|
||||
}()
|
||||
}
|
||||
|
||||
// Concurrent reads
|
||||
for i := 0; i < 20; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = manager.GetActiveSessionCount()
|
||||
_ = manager.GetTotalQueueDepth()
|
||||
_ = manager.IsAnySessionProcessing()
|
||||
_ = manager.GetAllSessions()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Should have all sessions (none are stale)
|
||||
assert.Equal(t, 5, manager.GetActiveSessionCount())
|
||||
// Should have 50 messages total
|
||||
assert.Equal(t, 50, manager.GetTotalQueueDepth())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user