Files
claude-mnemonic/internal/worker/middleware_test.go
T
lukaszraczylo d04b60517a Make things 'betterer' across the board (#23)
* Make things 'betterer' across the board

* fix: reorganize struct fields and config parameters for consistency

- [x] Reorder Config struct fields alphabetically and by related functionality
- [x] Reorganize Observation model fields with archival fields grouped together
- [x] Reorder ObservationStore fields to group related members
- [x] Reorder Store struct fields with health check caching grouped
- [x] Reorganize HealthInfo and PoolMetrics struct field order
- [x] Reorder maintenance Service struct fields logically
- [x] Reorganize MCP server handler parameter structs alphabetically
- [x] Reorder pattern detector candidate tracking fields
- [x] Reorganize search Manager struct fields by functionality
- [x] Reorder vector Client struct fields with mutex protections grouped
- [x] Reorganize handler request/response struct fields
- [x] Update handlers_test.go to expect wrapped response format
- [x] Reorder middleware TokenAuth and rate limiter fields
- [x] Reorganize Service struct fields with grouped functionality
- [x] Fix RateLimiter field ordering for clarity
- [x] Reorder CircuitBreaker metrics fields

* fix(security): improve JSON output safety and path traversal protection

- [x] Replace unsafe JSON string formatting with proper json.Marshal in export handler
- [x] Remove escapeJSONString helper function in favor of standard JSON marshaling
- [x] Add safeResolvePath function to validate paths and prevent directory traversal
- [x] Apply path traversal validation in captureFileMtimes operations
- [x] Cap result slice capacity in getRecentSearchQueries to prevent DoS via excessive allocation

* fix(sdk): improve path traversal protection and allocation safety

- [x] Enhance safeResolvePath with stricter validation using filepath.Rel
- [x] Reject paths containing ".." after cleaning to prevent traversal
- [x] Validate absolute paths are within cwd when cwd is specified
- [x] Apply safeResolvePath validation to GetFileContent for consistency
- [x] Add comprehensive test coverage for path traversal protection
- [x] Fix allocation safety in getRecentSearchQueries by using constant capacity
2026-01-11 01:51:20 +00:00

516 lines
13 KiB
Go

package worker
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestSecurityHeaders(t *testing.T) {
handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
// Check all security headers are set
tests := []struct {
header string
expected string
}{
{"X-Frame-Options", "DENY"},
{"X-Content-Type-Options", "nosniff"},
{"X-XSS-Protection", "1; mode=block"},
{"Referrer-Policy", "strict-origin-when-cross-origin"},
}
for _, tt := range tests {
if got := rr.Header().Get(tt.header); got != tt.expected {
t.Errorf("SecurityHeaders() %s = %q, want %q", tt.header, got, tt.expected)
}
}
}
func TestSecurityHeaders_CORS(t *testing.T) {
handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
tests := []struct {
name string
origin string
expectedOrigin string
expectCORS bool
}{
{
name: "localhost:37778 origin allowed",
origin: "http://localhost:37778",
expectCORS: true,
expectedOrigin: "http://localhost:37778",
},
{
name: "127.0.0.1:5173 origin allowed",
origin: "http://127.0.0.1:5173",
expectCORS: true,
expectedOrigin: "http://127.0.0.1:5173",
},
{
name: "localhost without port allowed",
origin: "http://localhost",
expectCORS: true,
expectedOrigin: "http://localhost",
},
{
name: "external origin blocked",
origin: "http://evil.com",
expectCORS: false,
},
{
name: "evil-localhost.com bypass attempt blocked",
origin: "http://evil-localhost.com",
expectCORS: false,
},
{
name: "localhost subdomain bypass attempt blocked",
origin: "http://localhost.evil.com",
expectCORS: false,
},
{
name: "unknown localhost port blocked",
origin: "http://localhost:9999",
expectCORS: false,
},
{
name: "no origin header",
origin: "",
expectCORS: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
if tt.origin != "" {
req.Header.Set("Origin", tt.origin)
}
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
cors := rr.Header().Get("Access-Control-Allow-Origin")
if tt.expectCORS {
if cors != tt.expectedOrigin {
t.Errorf("Expected CORS origin %q, got %q", tt.expectedOrigin, cors)
}
} else {
if cors != "" {
t.Errorf("Expected no CORS header, got %q", cors)
}
}
})
}
}
func TestMaxBodySize(t *testing.T) {
maxSize := int64(100)
handler := MaxBodySize(maxSize)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
tests := []struct {
name string
contentLength int64
expectedStatus int
}{
{
name: "within limit",
contentLength: 50,
expectedStatus: http.StatusOK,
},
{
name: "at limit",
contentLength: 100,
expectedStatus: http.StatusOK,
},
{
name: "exceeds limit",
contentLength: 150,
expectedStatus: http.StatusRequestEntityTooLarge,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("POST", "/test", nil)
req.ContentLength = tt.contentLength
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != tt.expectedStatus {
t.Errorf("MaxBodySize() status = %d, want %d", rr.Code, tt.expectedStatus)
}
})
}
}
func TestTokenAuth(t *testing.T) {
t.Run("disabled auth allows all requests", func(t *testing.T) {
ta, err := NewTokenAuth(false)
if err != nil {
t.Fatalf("NewTokenAuth() error = %v", err)
}
handler := ta.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected OK with disabled auth, got %d", rr.Code)
}
})
t.Run("enabled auth requires token", func(t *testing.T) {
ta, err := NewTokenAuth(true)
if err != nil {
t.Fatalf("NewTokenAuth() error = %v", err)
}
handler := ta.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Without token
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusUnauthorized {
t.Errorf("Expected Unauthorized without token, got %d", rr.Code)
}
// With correct token in X-Auth-Token header
req = httptest.NewRequest("GET", "/test", nil)
req.Header.Set("X-Auth-Token", ta.Token())
rr = httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected OK with correct token, got %d", rr.Code)
}
// With correct token in Authorization header
req = httptest.NewRequest("GET", "/test", nil)
req.Header.Set("Authorization", "Bearer "+ta.Token())
rr = httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected OK with Bearer token, got %d", rr.Code)
}
})
t.Run("exempt paths skip auth", func(t *testing.T) {
ta, err := NewTokenAuth(true)
if err != nil {
t.Fatalf("NewTokenAuth() error = %v", err)
}
handler := ta.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
exemptPaths := []string{"/health", "/api/health", "/api/ready"}
for _, path := range exemptPaths {
req := httptest.NewRequest("GET", path, nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected OK for exempt path %s, got %d", path, rr.Code)
}
}
})
}
func TestExpensiveOperationLimiter(t *testing.T) {
limiter := NewExpensiveOperationLimiter()
// First rebuild should be allowed
if !limiter.CanRebuild() {
t.Error("First rebuild should be allowed")
}
// Immediate second rebuild should be blocked
if limiter.CanRebuild() {
t.Error("Immediate second rebuild should be blocked")
}
}
func TestRequestID(t *testing.T) {
handler := RequestID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify request ID is in context
id := GetRequestID(r.Context())
if id == "" {
t.Error("Request ID should be set in context")
}
w.WriteHeader(http.StatusOK)
}))
t.Run("generates new request ID", func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Header().Get("X-Request-ID") == "" {
t.Error("X-Request-ID header should be set")
}
})
t.Run("uses existing request ID", func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("X-Request-ID", "test-id-12345")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Header().Get("X-Request-ID") != "test-id-12345" {
t.Errorf("Expected X-Request-ID to be test-id-12345, got %s", rr.Header().Get("X-Request-ID"))
}
})
}
func TestRequireJSONContentType(t *testing.T) {
handler := RequireJSONContentType(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
tests := []struct {
name string
method string
contentType string
expectedStatus int
}{
{
name: "GET request without content-type",
method: "GET",
contentType: "",
expectedStatus: http.StatusOK,
},
{
name: "POST with application/json",
method: "POST",
contentType: "application/json",
expectedStatus: http.StatusOK,
},
{
name: "POST with application/json; charset=utf-8",
method: "POST",
contentType: "application/json; charset=utf-8",
expectedStatus: http.StatusOK,
},
{
name: "POST without content-type (empty body)",
method: "POST",
contentType: "",
expectedStatus: http.StatusOK,
},
{
name: "POST with text/plain rejected",
method: "POST",
contentType: "text/plain",
expectedStatus: http.StatusUnsupportedMediaType,
},
{
name: "PUT with application/xml rejected",
method: "PUT",
contentType: "application/xml",
expectedStatus: http.StatusUnsupportedMediaType,
},
{
name: "PATCH with form-urlencoded rejected",
method: "PATCH",
contentType: "application/x-www-form-urlencoded",
expectedStatus: http.StatusUnsupportedMediaType,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(tt.method, "/test", nil)
if tt.contentType != "" {
req.Header.Set("Content-Type", tt.contentType)
}
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != tt.expectedStatus {
t.Errorf("Expected status %d, got %d", tt.expectedStatus, rr.Code)
}
})
}
}
func TestValidateProjectName(t *testing.T) {
tests := []struct {
name string
project string
wantError bool
}{
{
name: "empty project allowed",
project: "",
wantError: false,
},
{
name: "simple project name",
project: "my-project",
wantError: false,
},
{
name: "project with path",
project: "org/my-project",
wantError: false,
},
{
name: "project with underscore",
project: "my_project_v2",
wantError: false,
},
{
name: "project with dot",
project: "my.project.name",
wantError: false,
},
{
name: "path traversal attack",
project: "../../../etc/passwd",
wantError: true,
},
{
name: "hidden path traversal",
project: "project/../../secret",
wantError: true,
},
{
name: "shell injection attempt",
project: "project; rm -rf /",
wantError: true,
},
{
name: "backtick injection",
project: "project`whoami`",
wantError: true,
},
{
name: "special characters",
project: "project$HOME",
wantError: true,
},
{
name: "too long project name",
project: string(make([]byte, 501)),
wantError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateProjectName(tt.project)
if tt.wantError && err == nil {
t.Errorf("Expected error for project %q, got nil", tt.project)
}
if !tt.wantError && err != nil {
t.Errorf("Unexpected error for project %q: %v", tt.project, err)
}
})
}
}
func TestBulkOperationLimiter(t *testing.T) {
limiter := NewBulkOperationLimiter(1) // 1 second cooldown for testing
// First operation should be allowed
if !limiter.CanExecute() {
t.Error("First bulk operation should be allowed")
}
// Immediate second operation should be blocked
if limiter.CanExecute() {
t.Error("Immediate second bulk operation should be blocked")
}
// Check cooldown remaining
remaining := limiter.CooldownRemaining()
if remaining <= 0 || remaining > 1 {
t.Errorf("Expected cooldown remaining between 0-1 seconds, got %d", remaining)
}
// Check time since last op
since := limiter.TimeSinceLastOp()
if since < 0 || since > 1 {
t.Errorf("Expected time since last op between 0-1 seconds, got %d", since)
}
}
func TestSecurityHeaders_CSP(t *testing.T) {
handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
// Check CSP header is set
csp := rr.Header().Get("Content-Security-Policy")
if csp == "" {
t.Error("Content-Security-Policy header should be set")
}
if csp != "default-src 'self'" {
t.Errorf("Expected CSP to be \"default-src 'self'\", got %q", csp)
}
// Check Permissions-Policy header
pp := rr.Header().Get("Permissions-Policy")
if pp == "" {
t.Error("Permissions-Policy header should be set")
}
}
func TestSecurityHeaders_Preflight(t *testing.T) {
handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("OPTIONS", "/test", nil)
req.Header.Set("Origin", "http://localhost:3000")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
// OPTIONS should return 204 No Content
if rr.Code != http.StatusNoContent {
t.Errorf("Expected status 204 for OPTIONS, got %d", rr.Code)
}
// CORS headers should be set for allowed origin
if rr.Header().Get("Access-Control-Allow-Origin") != "http://localhost:3000" {
t.Errorf("CORS origin should be set for allowed origin")
}
if rr.Header().Get("Access-Control-Allow-Methods") == "" {
t.Error("Access-Control-Allow-Methods should be set")
}
}