Skip to content

Commit 3c26236

Browse files
authored
feat: add custom provider batch size caps (#41)
1 parent 4898715 commit 3c26236

File tree

7 files changed

+188
-26
lines changed

7 files changed

+188
-26
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -623,11 +623,12 @@ Works with any server that implements the OpenAI `/v1/embeddings` API format (ll
623623
"dimensions": 768,
624624
"apiKey": "{env:EMBED_API_KEY}",
625625
"maxTokens": 8192,
626-
"timeoutMs": 30000
626+
"timeoutMs": 30000,
627+
"maxBatchSize": 64
627628
}
628629
}
629630
```
630-
Required fields: `baseUrl`, `model`, `dimensions` (positive integer). Optional: `apiKey`, `maxTokens`, `timeoutMs` (default: 30000). `{env:VAR_NAME}` placeholders are resolved before config validation for fields that are actually used and throw if the referenced environment variable is missing or malformed.
631+
Required fields: `baseUrl`, `model`, `dimensions` (positive integer). Optional: `apiKey`, `maxTokens`, `timeoutMs` (default: 30000), and `maxBatchSize` (or `max_batch_size`), which caps the number of inputs sent per `/embeddings` request for servers like text-embeddings-inference. `{env:VAR_NAME}` placeholders are resolved before config validation for fields that are actually used; resolution throws if the referenced environment variable is missing or malformed.
631632

632633
## ⚠️ Tradeoffs
633634

src/config/schema.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ export interface CustomProviderConfig {
6666
concurrency?: number;
6767
/** Minimum delay between requests in milliseconds (default: 1000). Set to 0 for local servers. */
6868
requestIntervalMs?: number;
69+
maxBatchSize?: number;
70+
max_batch_size?: number;
6971
}
7072

7173
export interface CodebaseIndexConfig {
@@ -245,6 +247,11 @@ export function parseConfig(raw: unknown): ParsedCodebaseIndexConfig {
245247
timeoutMs: typeof rawCustom.timeoutMs === 'number' ? Math.max(1000, rawCustom.timeoutMs) : undefined,
246248
concurrency: typeof rawCustom.concurrency === 'number' ? Math.max(1, Math.floor(rawCustom.concurrency)) : undefined,
247249
requestIntervalMs: typeof rawCustom.requestIntervalMs === 'number' ? Math.max(0, Math.floor(rawCustom.requestIntervalMs)) : undefined,
250+
maxBatchSize: typeof rawCustom.maxBatchSize === 'number'
251+
? Math.max(1, Math.floor(rawCustom.maxBatchSize))
252+
: typeof rawCustom.max_batch_size === 'number'
253+
? Math.max(1, Math.floor(rawCustom.max_batch_size))
254+
: undefined,
248255
};
249256
// Warn if baseUrl doesn't end with an API version path like /v1.
250257
// Note: using console.warn here because Logger isn't initialized yet at config parse time.

src/embeddings/detector.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ export interface ProviderCredentials {
1515
export interface CustomModelInfo extends BaseModelInfo {
1616
provider: 'custom';
1717
timeoutMs: number;
18+
maxBatchSize?: number;
1819
}
1920

2021
export type ConfiguredProviderInfo = {
@@ -247,6 +248,7 @@ export function createCustomProviderInfo(config: CustomProviderConfig): Configur
247248
maxTokens: config.maxTokens ?? 8192,
248249
costPer1MTokens: 0,
249250
timeoutMs: config.timeoutMs ?? 30_000,
251+
maxBatchSize: config.maxBatchSize,
250252
},
251253
};
252254
}

src/embeddings/provider.ts

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -343,23 +343,28 @@ class CustomEmbeddingProvider implements EmbeddingProviderInterface {
343343
private modelInfo: CustomModelInfo
344344
) { }
345345

346-
async embedQuery(query: string): Promise<EmbeddingResult> {
347-
const result = await this.embedBatch([query]);
348-
return {
349-
embedding: result.embeddings[0],
350-
tokensUsed: result.totalTokensUsed,
351-
};
352-
}
346+
private splitIntoRequestBatches(texts: string[]): string[][] {
347+
const maxBatchSize = this.modelInfo.maxBatchSize;
353348

354-
async embedDocument(document: string): Promise<EmbeddingResult> {
355-
const result = await this.embedBatch([document]);
356-
return {
357-
embedding: result.embeddings[0],
358-
tokensUsed: result.totalTokensUsed,
359-
};
349+
if (!maxBatchSize || texts.length <= maxBatchSize) {
350+
return [texts];
351+
}
352+
353+
const batches: string[][] = [];
354+
for (let i = 0; i < texts.length; i += maxBatchSize) {
355+
batches.push(texts.slice(i, i + maxBatchSize));
356+
}
357+
return batches;
360358
}
361359

362-
async embedBatch(texts: string[]): Promise<EmbeddingBatchResult> {
360+
private async embedRequest(texts: string[]): Promise<EmbeddingBatchResult> {
361+
if (texts.length === 0) {
362+
return {
363+
embeddings: [],
364+
totalTokensUsed: 0,
365+
};
366+
}
367+
363368
const headers: Record<string, string> = {
364369
"Content-Type": "application/json",
365370
};
@@ -444,6 +449,39 @@ class CustomEmbeddingProvider implements EmbeddingProviderInterface {
444449
throw new Error("Custom embedding API returned unexpected response format. Expected OpenAI-compatible format with data[].embedding.");
445450
}
446451

452+
async embedQuery(query: string): Promise<EmbeddingResult> {
453+
const result = await this.embedBatch([query]);
454+
return {
455+
embedding: result.embeddings[0],
456+
tokensUsed: result.totalTokensUsed,
457+
};
458+
}
459+
460+
async embedDocument(document: string): Promise<EmbeddingResult> {
461+
const result = await this.embedBatch([document]);
462+
return {
463+
embedding: result.embeddings[0],
464+
tokensUsed: result.totalTokensUsed,
465+
};
466+
}
467+
468+
async embedBatch(texts: string[]): Promise<EmbeddingBatchResult> {
469+
const requestBatches = this.splitIntoRequestBatches(texts);
470+
const embeddings: number[][] = [];
471+
let totalTokensUsed = 0;
472+
473+
for (const batch of requestBatches) {
474+
const result = await this.embedRequest(batch);
475+
embeddings.push(...result.embeddings);
476+
totalTokensUsed += result.totalTokensUsed;
477+
}
478+
479+
return {
480+
embeddings,
481+
totalTokensUsed,
482+
};
483+
}
484+
447485
getModelInfo(): CustomModelInfo {
448486
return this.modelInfo;
449487
}

tests/config.test.ts

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,32 @@ describe("config schema", () => {
623623
expect(config.customProvider!.requestIntervalMs).toBe(0);
624624
});
625625

626+
it("should parse custom provider with maxBatchSize", () => {
627+
const config = parseConfig({
628+
embeddingProvider: "custom",
629+
customProvider: {
630+
baseUrl: "http://localhost:11434/v1",
631+
model: "test",
632+
dimensions: 768,
633+
maxBatchSize: 64,
634+
},
635+
});
636+
expect(config.customProvider!.maxBatchSize).toBe(64);
637+
});
638+
639+
it("should parse custom provider with max_batch_size alias", () => {
640+
const config = parseConfig({
641+
embeddingProvider: "custom",
642+
customProvider: {
643+
baseUrl: "http://localhost:11434/v1",
644+
model: "test",
645+
dimensions: 768,
646+
max_batch_size: 32,
647+
},
648+
});
649+
expect(config.customProvider!.maxBatchSize).toBe(32);
650+
});
651+
626652
it("should clamp concurrency to minimum of 1", () => {
627653
const config = parseConfig({
628654
embeddingProvider: "custom",
@@ -636,6 +662,19 @@ describe("config schema", () => {
636662
expect(config.customProvider!.concurrency).toBe(1);
637663
});
638664

665+
it("should clamp maxBatchSize to minimum of 1", () => {
666+
const config = parseConfig({
667+
embeddingProvider: "custom",
668+
customProvider: {
669+
baseUrl: "http://localhost:11434/v1",
670+
model: "test",
671+
dimensions: 768,
672+
maxBatchSize: 0,
673+
},
674+
});
675+
expect(config.customProvider!.maxBatchSize).toBe(1);
676+
});
677+
639678
it("should leave concurrency undefined when not provided", () => {
640679
const config = parseConfig({
641680
embeddingProvider: "custom",

tests/custom-provider.test.ts

Lines changed: 75 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
22
import { createEmbeddingProvider, CustomProviderNonRetryableError } from "../src/embeddings/provider.js";
3-
import { createCustomProviderInfo } from "../src/embeddings/detector.js";
3+
import { createCustomProviderInfo, type ConfiguredProviderInfo } from "../src/embeddings/detector.js";
44
import { Indexer } from "../src/indexer/index.js";
55
import { parseConfig } from "../src/config/schema.js";
66
import pRetry from "p-retry";
@@ -11,6 +11,30 @@ import * as path from "path";
1111
describe("CustomEmbeddingProvider", () => {
1212
let fetchSpy: ReturnType<typeof vi.spyOn>;
1313

14+
function getCustomProviderInfo(
15+
info: ConfiguredProviderInfo
16+
): Extract<ConfiguredProviderInfo, { provider: "custom" }> {
17+
expect(info.provider).toBe("custom");
18+
if (info.provider !== "custom") {
19+
throw new Error("Expected custom provider info");
20+
}
21+
return info;
22+
}
23+
24+
function getRejectedError<T>(promise: Promise<T>): Promise<Error> {
25+
return promise.then<Error>(
26+
() => {
27+
throw new Error("Expected promise to reject");
28+
},
29+
(error: unknown) => {
30+
if (error instanceof Error) {
31+
return error;
32+
}
33+
return new Error(String(error));
34+
}
35+
);
36+
}
37+
1438
beforeEach(() => {
1539
fetchSpy = vi.spyOn(globalThis, "fetch");
1640
});
@@ -94,6 +118,47 @@ describe("CustomEmbeddingProvider", () => {
94118
expect(result.totalTokensUsed).toBe(30);
95119
});
96120

121+
it("should split custom provider requests by maxBatchSize", async () => {
122+
fetchSpy
123+
.mockResolvedValueOnce(new Response(JSON.stringify({
124+
data: [
125+
{ embedding: new Array(768).fill(0.1) },
126+
{ embedding: new Array(768).fill(0.2) },
127+
],
128+
usage: { total_tokens: 20 },
129+
}), { status: 200 }))
130+
.mockResolvedValueOnce(new Response(JSON.stringify({
131+
data: [
132+
{ embedding: new Array(768).fill(0.3) },
133+
{ embedding: new Array(768).fill(0.4) },
134+
],
135+
usage: { total_tokens: 22 },
136+
}), { status: 200 }))
137+
.mockResolvedValueOnce(new Response(JSON.stringify({
138+
data: [
139+
{ embedding: new Array(768).fill(0.5) },
140+
],
141+
usage: { total_tokens: 11 },
142+
}), { status: 200 }));
143+
144+
const info = createCustomProviderInfo({
145+
baseUrl: "http://localhost:11434/v1",
146+
model: "nomic-embed-text",
147+
dimensions: 768,
148+
maxBatchSize: 2,
149+
});
150+
const provider = createEmbeddingProvider(info);
151+
152+
const result = await provider.embedBatch(["text1", "text2", "text3", "text4", "text5"]);
153+
154+
expect(fetchSpy).toHaveBeenCalledTimes(3);
155+
expect(JSON.parse((fetchSpy.mock.calls[0] as [string, RequestInit])[1].body as string).input).toEqual(["text1", "text2"]);
156+
expect(JSON.parse((fetchSpy.mock.calls[1] as [string, RequestInit])[1].body as string).input).toEqual(["text3", "text4"]);
157+
expect(JSON.parse((fetchSpy.mock.calls[2] as [string, RequestInit])[1].body as string).input).toEqual(["text5"]);
158+
expect(result.embeddings).toHaveLength(5);
159+
expect(result.totalTokensUsed).toBe(53);
160+
});
161+
97162
it("should estimate tokens when usage is not provided", async () => {
98163
fetchSpy.mockResolvedValueOnce(new Response(JSON.stringify({
99164
data: [{ embedding: new Array(768).fill(0) }],
@@ -213,28 +278,28 @@ describe("CustomEmbeddingProvider", () => {
213278
});
214279

215280
it("should default timeout to 30000ms", () => {
216-
const info = createCustomProviderInfo({
281+
const info = getCustomProviderInfo(createCustomProviderInfo({
217282
baseUrl: "http://localhost:11434/v1",
218283
model: "nomic-embed-text",
219284
dimensions: 768,
220-
});
285+
}));
221286
expect(info.modelInfo.timeoutMs).toBe(30000);
222287
});
223288

224289
it("should use custom timeout value from config", () => {
225-
const info = createCustomProviderInfo({
290+
const info = getCustomProviderInfo(createCustomProviderInfo({
226291
baseUrl: "http://localhost:11434/v1",
227292
model: "nomic-embed-text",
228293
dimensions: 768,
229294
timeoutMs: 60000,
230-
});
295+
}));
231296
expect(info.modelInfo.timeoutMs).toBe(60000);
232297
});
233298

234299
it("should throw non-retryable error on 4xx responses (except 429)", async () => {
235300
fetchSpy.mockResolvedValueOnce(new Response("Unauthorized", { status: 401 }));
236301
const provider = createProvider();
237-
const error = await provider.embedQuery("test").catch((e: Error) => e);
302+
const error = await getRejectedError(provider.embedQuery("test"));
238303
expect(error).toBeInstanceOf(CustomProviderNonRetryableError);
239304
expect(error.message).toContain("non-retryable");
240305
expect(error.message).toContain("401");
@@ -243,29 +308,29 @@ describe("CustomEmbeddingProvider", () => {
243308
it("should throw non-retryable error on 400 Bad Request", async () => {
244309
fetchSpy.mockResolvedValueOnce(new Response("Bad model name", { status: 400 }));
245310
const provider = createProvider();
246-
const error = await provider.embedQuery("test").catch((e: Error) => e);
311+
const error = await getRejectedError(provider.embedQuery("test"));
247312
expect(error).toBeInstanceOf(CustomProviderNonRetryableError);
248313
});
249314

250315
it("should throw non-retryable error on 403 Forbidden", async () => {
251316
fetchSpy.mockResolvedValueOnce(new Response("Forbidden", { status: 403 }));
252317
const provider = createProvider();
253-
const error = await provider.embedQuery("test").catch((e: Error) => e);
318+
const error = await getRejectedError(provider.embedQuery("test"));
254319
expect(error).toBeInstanceOf(CustomProviderNonRetryableError);
255320
});
256321

257322
it("should throw retryable error on 429 rate limit", async () => {
258323
fetchSpy.mockResolvedValueOnce(new Response("Rate limited", { status: 429 }));
259324
const provider = createProvider();
260-
const error = await provider.embedQuery("test").catch((e: Error) => e);
325+
const error = await getRejectedError(provider.embedQuery("test"));
261326
expect(error).not.toBeInstanceOf(CustomProviderNonRetryableError);
262327
expect(error.message).toContain("429");
263328
});
264329

265330
it("should throw retryable error on 5xx server errors", async () => {
266331
fetchSpy.mockResolvedValueOnce(new Response("Internal Server Error", { status: 500 }));
267332
const provider = createProvider();
268-
const error = await provider.embedQuery("test").catch((e: Error) => e);
333+
const error = await getRejectedError(provider.embedQuery("test"));
269334
expect(error).not.toBeInstanceOf(CustomProviderNonRetryableError);
270335
expect(error.message).toContain("500");
271336
});

tests/embeddings.test.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,5 +67,15 @@ describe("embeddings detector", () => {
6767
});
6868
expect(info.modelInfo.maxTokens).toBe(4096);
6969
});
70+
71+
it("should pass through optional maxBatchSize", () => {
72+
const info = createCustomProviderInfo({
73+
baseUrl: "http://localhost/v1",
74+
model: "test",
75+
dimensions: 512,
76+
maxBatchSize: 64,
77+
});
78+
expect(info.modelInfo.maxBatchSize).toBe(64);
79+
});
7080
});
7181
});

0 commit comments

Comments
 (0)