|
| 1 | +import { afterAll, afterEach, beforeEach, describe, expect, it, vi } from "vitest"; |
| 2 | + |
| 3 | +// --- Mocking Setup --- |
| 4 | + |
| 5 | +// Mock OpenAIEmbeddings |
| 6 | +const mockEmbedQuery = vi.fn().mockResolvedValue([0.1, 0.2, 0.3]); |
| 7 | +const mockEmbedDocuments = vi.fn().mockResolvedValue([[0.1, 0.2, 0.3]]); // Keep this if addDocuments is tested elsewhere |
| 8 | + |
| 9 | +// Mock the module to export a mock function for the class constructor |
| 10 | +vi.mock("@langchain/openai", () => ({ |
| 11 | + OpenAIEmbeddings: vi.fn(), // Mock the class export as a vi.fn() |
| 12 | +})); |
| 13 | + |
| 14 | +// Mock better-sqlite3 |
| 15 | +const mockStatementAll = vi.fn().mockReturnValue([]); |
| 16 | +// Ensure the mock statement object covers methods used by *all* statements prepared in DocumentStore |
| 17 | +const mockStatement = { |
| 18 | + all: mockStatementAll, |
| 19 | + run: vi.fn().mockReturnValue({ changes: 0, lastInsertRowid: 1 }), // Mock run for insert/delete |
| 20 | + get: vi.fn().mockReturnValue(undefined), // Mock get for getById/checkExists etc. |
| 21 | +}; |
| 22 | +const mockPrepare = vi.fn().mockReturnValue(mockStatement); |
| 23 | +const mockDb = { |
| 24 | + prepare: mockPrepare, |
| 25 | + exec: vi.fn(), |
| 26 | + transaction: vi.fn((fn) => fn()), |
| 27 | + close: vi.fn(), |
| 28 | +}; |
| 29 | +vi.mock("better-sqlite3", () => ({ |
| 30 | + default: vi.fn(() => mockDb), // Mock the default export (constructor) |
| 31 | +})); |
| 32 | + |
| 33 | +// Mock sqlite-vec |
| 34 | +vi.mock("sqlite-vec", () => ({ |
| 35 | + load: vi.fn(), |
| 36 | +})); |
| 37 | + |
| 38 | +// --- Test Suite --- |
| 39 | + |
| 40 | +// Import the mocked constructor function |
| 41 | +import { OpenAIEmbeddings } from "@langchain/openai"; |
| 42 | +// Import DocumentStore AFTER mocks are defined |
| 43 | +import { DocumentStore } from "./DocumentStore"; |
| 44 | + |
| 45 | +// Cast OpenAIEmbeddings to the correct Vitest mock type for configuration |
| 46 | +const MockedOpenAIEmbeddingsConstructor = OpenAIEmbeddings as ReturnType<typeof vi.fn>; |
| 47 | + |
| 48 | +describe("DocumentStore", () => { |
| 49 | + let documentStore: DocumentStore; |
| 50 | + |
| 51 | + beforeEach(async () => { |
| 52 | + vi.clearAllMocks(); // Clear call history etc. |
| 53 | + |
| 54 | + // Configure the mock constructor's implementation for THIS test run |
| 55 | + MockedOpenAIEmbeddingsConstructor.mockImplementation(() => ({ |
| 56 | + embedQuery: mockEmbedQuery, |
| 57 | + embedDocuments: mockEmbedDocuments, |
| 58 | + })); |
| 59 | + mockPrepare.mockReturnValue(mockStatement); // <-- Re-configure prepare mock return value |
| 60 | + |
| 61 | + // Now create the store and initialize. |
| 62 | + // initialize() will call 'new OpenAIEmbeddings()', which uses our fresh mock implementation. |
| 63 | + documentStore = new DocumentStore(":memory:"); |
| 64 | + await documentStore.initialize(); |
| 65 | + }); |
| 66 | + |
| 67 | + afterAll(() => { |
| 68 | + vi.restoreAllMocks(); |
| 69 | + }); |
| 70 | + |
| 71 | + describe("findByContent", () => { |
| 72 | + const library = "test-lib"; |
| 73 | + const version = "1.0.0"; |
| 74 | + const limit = 10; |
| 75 | + |
| 76 | + it("should call embedQuery and prepare/all with escaped FTS query for double quotes", async () => { |
| 77 | + const query = 'find "quotes"'; |
| 78 | + const expectedFtsQuery = '"find ""quotes"""'; // Escaped and wrapped |
| 79 | + |
| 80 | + await documentStore.findByContent(library, version, query, limit); |
| 81 | + |
| 82 | + // 1. Check if embedQuery was called |
| 83 | + expect(mockEmbedQuery).toHaveBeenCalledWith(query); |
| 84 | + expect(mockEmbedQuery).toHaveBeenCalledTimes(1); |
| 85 | + |
| 86 | + // 2. Check if db.prepare was called correctly during findByContent |
| 87 | + // It's called multiple times during initialize, so check the specific call |
| 88 | + const prepareCall = mockPrepare.mock.calls.find((call) => |
| 89 | + call[0].includes("WITH vec_scores AS"), |
| 90 | + ); |
| 91 | + expect(prepareCall).toBeDefined(); |
| 92 | + |
| 93 | + // 3. Check the arguments passed to the statement's 'all' method |
| 94 | + expect(mockStatementAll).toHaveBeenCalledTimes(1); // Only the findByContent call should use 'all' |
| 95 | + const lastCallArgs = mockStatementAll.mock.lastCall; |
| 96 | + expect(lastCallArgs).toEqual([ |
| 97 | + library.toLowerCase(), |
| 98 | + version.toLowerCase(), |
| 99 | + expect.any(String), // Embedding JSON |
| 100 | + limit, |
| 101 | + library.toLowerCase(), |
| 102 | + version.toLowerCase(), |
| 103 | + expectedFtsQuery, // Check the escaped query string |
| 104 | + limit, |
| 105 | + ]); |
| 106 | + }); |
| 107 | + |
| 108 | + it("should correctly escape FTS operators", async () => { |
| 109 | + const query = "search AND this OR that"; |
| 110 | + const expectedFtsQuery = '"search AND this OR that"'; |
| 111 | + await documentStore.findByContent(library, version, query, limit); |
| 112 | + expect(mockStatementAll).toHaveBeenCalledTimes(1); |
| 113 | + const lastCallArgs = mockStatementAll.mock.lastCall; |
| 114 | + expect(lastCallArgs?.[6]).toBe(expectedFtsQuery); // Check only the FTS query argument |
| 115 | + }); |
| 116 | + |
| 117 | + it("should correctly escape parentheses", async () => { |
| 118 | + const query = "function(arg)"; |
| 119 | + const expectedFtsQuery = '"function(arg)"'; |
| 120 | + await documentStore.findByContent(library, version, query, limit); |
| 121 | + expect(mockStatementAll).toHaveBeenCalledTimes(1); |
| 122 | + const lastCallArgs = mockStatementAll.mock.lastCall; |
| 123 | + expect(lastCallArgs?.[6]).toBe(expectedFtsQuery); |
| 124 | + }); |
| 125 | + |
| 126 | + it("should correctly escape asterisks", async () => { |
| 127 | + const query = "wildcard*"; |
| 128 | + const expectedFtsQuery = '"wildcard*"'; |
| 129 | + await documentStore.findByContent(library, version, query, limit); |
| 130 | + expect(mockStatementAll).toHaveBeenCalledTimes(1); |
| 131 | + const lastCallArgs = mockStatementAll.mock.lastCall; |
| 132 | + expect(lastCallArgs?.[6]).toBe(expectedFtsQuery); |
| 133 | + }); |
| 134 | + |
| 135 | + it("should correctly escape already quoted strings", async () => { |
| 136 | + const query = '"already quoted"'; |
| 137 | + const expectedFtsQuery = '"""already quoted"""'; |
| 138 | + await documentStore.findByContent(library, version, query, limit); |
| 139 | + expect(mockStatementAll).toHaveBeenCalledTimes(1); |
| 140 | + const lastCallArgs = mockStatementAll.mock.lastCall; |
| 141 | + expect(lastCallArgs?.[6]).toBe(expectedFtsQuery); |
| 142 | + }); |
| 143 | + |
| 144 | + it("should correctly handle empty string", async () => { |
| 145 | + const query = ""; |
| 146 | + const expectedFtsQuery = '""'; |
| 147 | + await documentStore.findByContent(library, version, query, limit); |
| 148 | + expect(mockStatementAll).toHaveBeenCalledTimes(1); |
| 149 | + const lastCallArgs = mockStatementAll.mock.lastCall; |
| 150 | + expect(lastCallArgs?.[6]).toBe(expectedFtsQuery); |
| 151 | + }); |
| 152 | + }); |
| 153 | +}); |
0 commit comments