Skip to content

Commit bcf01a8

Browse files
committed
fix(store): escape FTS query to handle special characters
Adds escaping for the full-text search query in DocumentStore.findByContent to prevent errors and ensure correct search behavior when the query contains special characters like quotes, parentheses, or FTS operators. Also adds unit tests to verify the escaping logic. Fixes #10
1 parent 3763168 commit bcf01a8

File tree

2 files changed

+166
-1
lines changed

2 files changed

+166
-1
lines changed

src/store/DocumentStore.test.ts

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
import { afterAll, afterEach, beforeEach, describe, expect, it, vi } from "vitest";
2+
3+
// --- Mocking Setup ---
4+
5+
// Mock OpenAIEmbeddings
6+
const mockEmbedQuery = vi.fn().mockResolvedValue([0.1, 0.2, 0.3]);
7+
const mockEmbedDocuments = vi.fn().mockResolvedValue([[0.1, 0.2, 0.3]]); // Keep this if addDocuments is tested elsewhere
8+
9+
// Mock the module to export a mock function for the class constructor
10+
vi.mock("@langchain/openai", () => ({
11+
OpenAIEmbeddings: vi.fn(), // Mock the class export as a vi.fn()
12+
}));
13+
14+
// Mock better-sqlite3
15+
const mockStatementAll = vi.fn().mockReturnValue([]);
16+
// Ensure the mock statement object covers methods used by *all* statements prepared in DocumentStore
17+
const mockStatement = {
18+
all: mockStatementAll,
19+
run: vi.fn().mockReturnValue({ changes: 0, lastInsertRowid: 1 }), // Mock run for insert/delete
20+
get: vi.fn().mockReturnValue(undefined), // Mock get for getById/checkExists etc.
21+
};
22+
const mockPrepare = vi.fn().mockReturnValue(mockStatement);
23+
const mockDb = {
24+
prepare: mockPrepare,
25+
exec: vi.fn(),
26+
transaction: vi.fn((fn) => fn()),
27+
close: vi.fn(),
28+
};
29+
vi.mock("better-sqlite3", () => ({
30+
default: vi.fn(() => mockDb), // Mock the default export (constructor)
31+
}));
32+
33+
// Mock sqlite-vec
34+
vi.mock("sqlite-vec", () => ({
35+
load: vi.fn(),
36+
}));
37+
38+
// --- Test Suite ---
39+
40+
// Import the mocked constructor function
41+
import { OpenAIEmbeddings } from "@langchain/openai";
42+
// Import DocumentStore AFTER mocks are defined
43+
import { DocumentStore } from "./DocumentStore";
44+
45+
// Cast OpenAIEmbeddings to the correct Vitest mock type for configuration
46+
const MockedOpenAIEmbeddingsConstructor = OpenAIEmbeddings as ReturnType<typeof vi.fn>;
47+
48+
describe("DocumentStore", () => {
49+
let documentStore: DocumentStore;
50+
51+
beforeEach(async () => {
52+
vi.clearAllMocks(); // Clear call history etc.
53+
54+
// Configure the mock constructor's implementation for THIS test run
55+
MockedOpenAIEmbeddingsConstructor.mockImplementation(() => ({
56+
embedQuery: mockEmbedQuery,
57+
embedDocuments: mockEmbedDocuments,
58+
}));
59+
mockPrepare.mockReturnValue(mockStatement); // <-- Re-configure prepare mock return value
60+
61+
// Now create the store and initialize.
62+
// initialize() will call 'new OpenAIEmbeddings()', which uses our fresh mock implementation.
63+
documentStore = new DocumentStore(":memory:");
64+
await documentStore.initialize();
65+
});
66+
67+
afterAll(() => {
68+
vi.restoreAllMocks();
69+
});
70+
71+
describe("findByContent", () => {
72+
const library = "test-lib";
73+
const version = "1.0.0";
74+
const limit = 10;
75+
76+
it("should call embedQuery and prepare/all with escaped FTS query for double quotes", async () => {
77+
const query = 'find "quotes"';
78+
const expectedFtsQuery = '"find ""quotes"""'; // Escaped and wrapped
79+
80+
await documentStore.findByContent(library, version, query, limit);
81+
82+
// 1. Check if embedQuery was called
83+
expect(mockEmbedQuery).toHaveBeenCalledWith(query);
84+
expect(mockEmbedQuery).toHaveBeenCalledTimes(1);
85+
86+
// 2. Check if db.prepare was called correctly during findByContent
87+
// It's called multiple times during initialize, so check the specific call
88+
const prepareCall = mockPrepare.mock.calls.find((call) =>
89+
call[0].includes("WITH vec_scores AS"),
90+
);
91+
expect(prepareCall).toBeDefined();
92+
93+
// 3. Check the arguments passed to the statement's 'all' method
94+
expect(mockStatementAll).toHaveBeenCalledTimes(1); // Only the findByContent call should use 'all'
95+
const lastCallArgs = mockStatementAll.mock.lastCall;
96+
expect(lastCallArgs).toEqual([
97+
library.toLowerCase(),
98+
version.toLowerCase(),
99+
expect.any(String), // Embedding JSON
100+
limit,
101+
library.toLowerCase(),
102+
version.toLowerCase(),
103+
expectedFtsQuery, // Check the escaped query string
104+
limit,
105+
]);
106+
});
107+
108+
it("should correctly escape FTS operators", async () => {
109+
const query = "search AND this OR that";
110+
const expectedFtsQuery = '"search AND this OR that"';
111+
await documentStore.findByContent(library, version, query, limit);
112+
expect(mockStatementAll).toHaveBeenCalledTimes(1);
113+
const lastCallArgs = mockStatementAll.mock.lastCall;
114+
expect(lastCallArgs?.[6]).toBe(expectedFtsQuery); // Check only the FTS query argument
115+
});
116+
117+
it("should correctly escape parentheses", async () => {
118+
const query = "function(arg)";
119+
const expectedFtsQuery = '"function(arg)"';
120+
await documentStore.findByContent(library, version, query, limit);
121+
expect(mockStatementAll).toHaveBeenCalledTimes(1);
122+
const lastCallArgs = mockStatementAll.mock.lastCall;
123+
expect(lastCallArgs?.[6]).toBe(expectedFtsQuery);
124+
});
125+
126+
it("should correctly escape asterisks", async () => {
127+
const query = "wildcard*";
128+
const expectedFtsQuery = '"wildcard*"';
129+
await documentStore.findByContent(library, version, query, limit);
130+
expect(mockStatementAll).toHaveBeenCalledTimes(1);
131+
const lastCallArgs = mockStatementAll.mock.lastCall;
132+
expect(lastCallArgs?.[6]).toBe(expectedFtsQuery);
133+
});
134+
135+
it("should correctly escape already quoted strings", async () => {
136+
const query = '"already quoted"';
137+
const expectedFtsQuery = '"""already quoted"""';
138+
await documentStore.findByContent(library, version, query, limit);
139+
expect(mockStatementAll).toHaveBeenCalledTimes(1);
140+
const lastCallArgs = mockStatementAll.mock.lastCall;
141+
expect(lastCallArgs?.[6]).toBe(expectedFtsQuery);
142+
});
143+
144+
it("should correctly handle empty string", async () => {
145+
const query = "";
146+
const expectedFtsQuery = '""';
147+
await documentStore.findByContent(library, version, query, limit);
148+
expect(mockStatementAll).toHaveBeenCalledTimes(1);
149+
const lastCallArgs = mockStatementAll.mock.lastCall;
150+
expect(lastCallArgs?.[6]).toBe(expectedFtsQuery);
151+
});
152+
});
153+
});

src/store/DocumentStore.ts

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,17 @@ export class DocumentStore {
177177
});
178178
}
179179

180+
/**
181+
* Escapes a query string for use with SQLite FTS5 MATCH operator.
182+
* Wraps the query in double quotes and escapes internal double quotes.
183+
*/
184+
private escapeFtsQuery(query: string): string {
185+
// Escape internal double quotes by doubling them
186+
const escapedQuotes = query.replace(/"/g, '""');
187+
// Wrap the entire string in double quotes
188+
return `"${escapedQuotes}"`;
189+
}
190+
180191
/**
181192
* Initializes database connection and ensures readiness
182193
*/
@@ -354,6 +365,7 @@ export class DocumentStore {
354365
): Promise<Document[]> {
355366
try {
356367
const embedding = await this.embeddings.embedQuery(query);
368+
const ftsQuery = this.escapeFtsQuery(query); // Escape the query for FTS
357369

358370
const stmt = this.db.prepare(`
359371
WITH vec_scores AS (
@@ -398,7 +410,7 @@ export class DocumentStore {
398410
limit,
399411
library.toLowerCase(),
400412
version.toLowerCase(),
401-
query,
413+
ftsQuery, // Use the escaped query
402414
limit,
403415
) as RawSearchResult[];
404416

0 commit comments

Comments
 (0)