11import { describe , it , expect , vi , beforeEach , afterEach } from "vitest" ;
22import { createEmbeddingProvider , CustomProviderNonRetryableError } from "../src/embeddings/provider.js" ;
3- import { createCustomProviderInfo } from "../src/embeddings/detector.js" ;
3+ import { createCustomProviderInfo , type ConfiguredProviderInfo } from "../src/embeddings/detector.js" ;
44import { Indexer } from "../src/indexer/index.js" ;
55import { parseConfig } from "../src/config/schema.js" ;
66import pRetry from "p-retry" ;
@@ -11,6 +11,30 @@ import * as path from "path";
1111describe ( "CustomEmbeddingProvider" , ( ) => {
1212 let fetchSpy : ReturnType < typeof vi . spyOn > ;
1313
/**
 * Checks that the given provider info is the "custom" variant and returns it
 * typed as that variant.
 *
 * @param info - Provider info to check.
 * @returns The same object, narrowed to the `provider: "custom"` member of
 *   `ConfiguredProviderInfo`.
 * @throws Error "Expected custom provider info" when `info.provider` is not
 *   "custom" (the preceding `expect` assertion runs first).
 */
function getCustomProviderInfo(
  info: ConfiguredProviderInfo
): Extract<ConfiguredProviderInfo, { provider: "custom" }> {
  // Test-framework assertion on the discriminant comes first.
  expect(info.provider).toBe("custom");

  // The explicit comparison is what lets the compiler narrow the union.
  if (info.provider === "custom") {
    return info;
  }
  throw new Error("Expected custom provider info");
}
23+
/**
 * Waits for a promise that is expected to reject and returns the rejection
 * as an `Error`.
 *
 * @param promise - The promise that should reject.
 * @returns The rejection reason itself when it is an `Error`; otherwise a new
 *   `Error` whose message is `String(reason)`.
 * @throws Error "Expected promise to reject" (as a rejection of the returned
 *   promise) when `promise` fulfills instead of rejecting.
 */
async function getRejectedError<T>(promise: Promise<T>): Promise<Error> {
  try {
    await promise;
  } catch (reason: unknown) {
    // Non-Error rejection values are wrapped so callers always get an Error.
    return reason instanceof Error ? reason : new Error(String(reason));
  }
  // Thrown outside the try block so it is not mistaken for the awaited rejection.
  throw new Error("Expected promise to reject");
}
37+
// Before each test, install a Vitest spy on globalThis.fetch and keep it in
// fetchSpy so individual tests can queue mocked responses on it.
1438 beforeEach ( ( ) => {
1539 fetchSpy = vi . spyOn ( globalThis , "fetch" ) ;
1640 } ) ;
@@ -94,6 +118,47 @@ describe("CustomEmbeddingProvider", () => {
94118 expect ( result . totalTokensUsed ) . toBe ( 30 ) ;
95119 } ) ;
96120
// Five inputs with maxBatchSize 2 are expected to produce three fetch calls
// (2 + 2 + 1 inputs), with embeddings and token usage combined across them.
it("should split custom provider requests by maxBatchSize", async () => {
  // One mocked HTTP response per expected request, queued in call order.
  const mockedBatches = [
    { fillValues: [0.1, 0.2], totalTokens: 20 },
    { fillValues: [0.3, 0.4], totalTokens: 22 },
    { fillValues: [0.5], totalTokens: 11 },
  ];
  for (const { fillValues, totalTokens } of mockedBatches) {
    fetchSpy.mockResolvedValueOnce(new Response(JSON.stringify({
      data: fillValues.map((value) => ({ embedding: new Array(768).fill(value) })),
      usage: { total_tokens: totalTokens },
    }), { status: 200 }));
  }

  const provider = createEmbeddingProvider(createCustomProviderInfo({
    baseUrl: "http://localhost:11434/v1",
    model: "nomic-embed-text",
    dimensions: 768,
    maxBatchSize: 2,
  }));

  const result = await provider.embedBatch(["text1", "text2", "text3", "text4", "text5"]);

  // Reads the `input` field from the JSON body passed to the n-th fetch call.
  const sentInput = (callIndex: number): unknown =>
    JSON.parse((fetchSpy.mock.calls[callIndex] as [string, RequestInit])[1].body as string).input;

  expect(fetchSpy).toHaveBeenCalledTimes(3);
  expect(sentInput(0)).toEqual(["text1", "text2"]);
  expect(sentInput(1)).toEqual(["text3", "text4"]);
  expect(sentInput(2)).toEqual(["text5"]);
  expect(result.embeddings).toHaveLength(5);
  // 20 + 22 + 11 tokens reported by the three mocked responses.
  expect(result.totalTokensUsed).toBe(53);
});
161+
97162 it ( "should estimate tokens when usage is not provided" , async ( ) => {
98163 fetchSpy . mockResolvedValueOnce ( new Response ( JSON . stringify ( {
99164 data : [ { embedding : new Array ( 768 ) . fill ( 0 ) } ] ,
@@ -213,28 +278,28 @@ describe("CustomEmbeddingProvider", () => {
213278 } ) ;
214279
// When no timeoutMs is supplied, the custom provider info is expected to
// report a timeout of 30000 ms.
it("should default timeout to 30000ms", () => {
  // getCustomProviderInfo narrows the result to the "custom" variant of
  // ConfiguredProviderInfo before modelInfo.timeoutMs is read.
  const info = getCustomProviderInfo(createCustomProviderInfo({
    baseUrl: "http://localhost:11434/v1",
    model: "nomic-embed-text",
    dimensions: 768,
  }));
  expect(info.modelInfo.timeoutMs).toBe(30000);
});
223288
// A timeoutMs given in the config is expected to be carried through to
// modelInfo.timeoutMs unchanged.
it("should use custom timeout value from config", () => {
  // getCustomProviderInfo narrows the result to the "custom" variant of
  // ConfiguredProviderInfo before modelInfo.timeoutMs is read.
  const info = getCustomProviderInfo(createCustomProviderInfo({
    baseUrl: "http://localhost:11434/v1",
    model: "nomic-embed-text",
    dimensions: 768,
    timeoutMs: 60000,
  }));
  expect(info.modelInfo.timeoutMs).toBe(60000);
});
233298
234299 it ( "should throw non-retryable error on 4xx responses (except 429)" , async ( ) => {
235300 fetchSpy . mockResolvedValueOnce ( new Response ( "Unauthorized" , { status : 401 } ) ) ;
236301 const provider = createProvider ( ) ;
237- const error = await provider . embedQuery ( "test" ) . catch ( ( e : Error ) => e ) ;
302+ const error = await getRejectedError ( provider . embedQuery ( "test" ) ) ;
238303 expect ( error ) . toBeInstanceOf ( CustomProviderNonRetryableError ) ;
239304 expect ( error . message ) . toContain ( "non-retryable" ) ;
240305 expect ( error . message ) . toContain ( "401" ) ;
@@ -243,29 +308,29 @@ describe("CustomEmbeddingProvider", () => {
// A mocked 400 response must make embedQuery reject with
// CustomProviderNonRetryableError.
it("should throw non-retryable error on 400 Bad Request", async () => {
  fetchSpy.mockResolvedValueOnce(new Response("Bad model name", { status: 400 }));
  const provider = createProvider();
  // getRejectedError fails the test if the call unexpectedly resolves,
  // instead of treating the resolved value as the error.
  const error = await getRejectedError(provider.embedQuery("test"));
  expect(error).toBeInstanceOf(CustomProviderNonRetryableError);
});
249314
// A mocked 403 response must make embedQuery reject with
// CustomProviderNonRetryableError.
it("should throw non-retryable error on 403 Forbidden", async () => {
  fetchSpy.mockResolvedValueOnce(new Response("Forbidden", { status: 403 }));
  const provider = createProvider();
  // getRejectedError fails the test if the call unexpectedly resolves,
  // instead of treating the resolved value as the error.
  const error = await getRejectedError(provider.embedQuery("test"));
  expect(error).toBeInstanceOf(CustomProviderNonRetryableError);
});
256321
// A mocked 429 response must make embedQuery reject with an error that is
// NOT CustomProviderNonRetryableError and whose message mentions the status.
it("should throw retryable error on 429 rate limit", async () => {
  fetchSpy.mockResolvedValueOnce(new Response("Rate limited", { status: 429 }));
  const provider = createProvider();
  // getRejectedError fails the test if the call unexpectedly resolves,
  // instead of treating the resolved value as the error.
  const error = await getRejectedError(provider.embedQuery("test"));
  expect(error).not.toBeInstanceOf(CustomProviderNonRetryableError);
  expect(error.message).toContain("429");
});
264329
// A mocked 500 response must make embedQuery reject with an error that is
// NOT CustomProviderNonRetryableError and whose message mentions the status.
it("should throw retryable error on 5xx server errors", async () => {
  fetchSpy.mockResolvedValueOnce(new Response("Internal Server Error", { status: 500 }));
  const provider = createProvider();
  // getRejectedError fails the test if the call unexpectedly resolves,
  // instead of treating the resolved value as the error.
  const error = await getRejectedError(provider.embedQuery("test"));
  expect(error).not.toBeInstanceOf(CustomProviderNonRetryableError);
  expect(error.message).toContain("500");
});
0 commit comments