Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -623,11 +623,12 @@ Works with any server that implements the OpenAI `/v1/embeddings` API format (ll
"dimensions": 768,
"apiKey": "{env:EMBED_API_KEY}",
"maxTokens": 8192,
"timeoutMs": 30000
"timeoutMs": 30000,
"maxBatchSize": 64
}
}
```
Required fields: `baseUrl`, `model`, `dimensions` (positive integer). Optional: `apiKey`, `maxTokens`, `timeoutMs` (default: 30000). `{env:VAR_NAME}` placeholders are resolved before config validation for fields that are actually used and throw if the referenced environment variable is missing or malformed.
Required fields: `baseUrl`, `model`, `dimensions` (positive integer). Optional: `apiKey`, `maxTokens`, `timeoutMs` (default: 30000), and `maxBatchSize` (alias: `max_batch_size`), which caps the number of inputs sent per `/embeddings` request; this is useful for servers such as text-embeddings-inference that limit batch size. `maxBatchSize` is floored to an integer and clamped to a minimum of 1. `{env:VAR_NAME}` placeholders are resolved before config validation for fields that are actually used and throw if the referenced environment variable is missing or malformed.

## ⚠️ Tradeoffs

Expand Down
7 changes: 7 additions & 0 deletions src/config/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ export interface CustomProviderConfig {
concurrency?: number;
/** Minimum delay between requests in milliseconds (default: 1000). Set to 0 for local servers. */
requestIntervalMs?: number;
maxBatchSize?: number;
max_batch_size?: number;
}

export interface CodebaseIndexConfig {
Expand Down Expand Up @@ -245,6 +247,11 @@ export function parseConfig(raw: unknown): ParsedCodebaseIndexConfig {
timeoutMs: typeof rawCustom.timeoutMs === 'number' ? Math.max(1000, rawCustom.timeoutMs) : undefined,
concurrency: typeof rawCustom.concurrency === 'number' ? Math.max(1, Math.floor(rawCustom.concurrency)) : undefined,
requestIntervalMs: typeof rawCustom.requestIntervalMs === 'number' ? Math.max(0, Math.floor(rawCustom.requestIntervalMs)) : undefined,
maxBatchSize: typeof rawCustom.maxBatchSize === 'number'
? Math.max(1, Math.floor(rawCustom.maxBatchSize))
: typeof rawCustom.max_batch_size === 'number'
? Math.max(1, Math.floor(rawCustom.max_batch_size))
: undefined,
};
// Warn if baseUrl doesn't end with an API version path like /v1.
// Note: using console.warn here because Logger isn't initialized yet at config parse time.
Expand Down
2 changes: 2 additions & 0 deletions src/embeddings/detector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ export interface ProviderCredentials {
/**
 * Model metadata for the `custom` embedding provider, extending the shared
 * `BaseModelInfo` fields with custom-provider-only settings.
 */
export interface CustomModelInfo extends BaseModelInfo {
/** Discriminant identifying this model info as belonging to the custom provider. */
provider: 'custom';
/** Timeout in milliseconds; `createCustomProviderInfo` falls back to 30_000 when the config omits it. */
timeoutMs: number;
/** Optional cap on how many texts go into a single embedding request; when unset, batches are not split. */
maxBatchSize?: number;
}

export type ConfiguredProviderInfo = {
Expand Down Expand Up @@ -247,6 +248,7 @@ export function createCustomProviderInfo(config: CustomProviderConfig): Configur
maxTokens: config.maxTokens ?? 8192,
costPer1MTokens: 0,
timeoutMs: config.timeoutMs ?? 30_000,
maxBatchSize: config.maxBatchSize,
},
};
}
66 changes: 52 additions & 14 deletions src/embeddings/provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -343,23 +343,28 @@ class CustomEmbeddingProvider implements EmbeddingProviderInterface {
private modelInfo: CustomModelInfo
) { }

async embedQuery(query: string): Promise<EmbeddingResult> {
const result = await this.embedBatch([query]);
return {
embedding: result.embeddings[0],
tokensUsed: result.totalTokensUsed,
};
}
private splitIntoRequestBatches(texts: string[]): string[][] {
const maxBatchSize = this.modelInfo.maxBatchSize;

async embedDocument(document: string): Promise<EmbeddingResult> {
const result = await this.embedBatch([document]);
return {
embedding: result.embeddings[0],
tokensUsed: result.totalTokensUsed,
};
if (!maxBatchSize || texts.length <= maxBatchSize) {
return [texts];
}

const batches: string[][] = [];
for (let i = 0; i < texts.length; i += maxBatchSize) {
batches.push(texts.slice(i, i + maxBatchSize));
}
return batches;
}

async embedBatch(texts: string[]): Promise<EmbeddingBatchResult> {
private async embedRequest(texts: string[]): Promise<EmbeddingBatchResult> {
if (texts.length === 0) {
return {
embeddings: [],
totalTokensUsed: 0,
};
}

const headers: Record<string, string> = {
"Content-Type": "application/json",
};
Expand Down Expand Up @@ -444,6 +449,39 @@ class CustomEmbeddingProvider implements EmbeddingProviderInterface {
throw new Error("Custom embedding API returned unexpected response format. Expected OpenAI-compatible format with data[].embedding.");
}

/**
 * Embeds a single query string by sending it through `embedBatch` as a
 * one-element batch.
 *
 * @param query - Text to embed.
 * @returns The first embedding of the batch result and the batch's total token count.
 */
async embedQuery(query: string): Promise<EmbeddingResult> {
  const { embeddings, totalTokensUsed } = await this.embedBatch([query]);
  return { embedding: embeddings[0], tokensUsed: totalTokensUsed };
}

/**
 * Embeds a single document string by sending it through `embedBatch` as a
 * one-element batch.
 *
 * @param document - Text to embed.
 * @returns The first embedding of the batch result and the batch's total token count.
 */
async embedDocument(document: string): Promise<EmbeddingResult> {
  const { embeddings, totalTokensUsed } = await this.embedBatch([document]);
  return { embedding: embeddings[0], tokensUsed: totalTokensUsed };
}

/**
 * Embeds a list of texts, issuing one `embedRequest` call per chunk produced
 * by `splitIntoRequestBatches` and concatenating the results.
 *
 * Chunks are awaited one after another, so the returned embeddings follow the
 * order of `texts`.
 *
 * @param texts - Texts to embed.
 * @returns All embeddings in request order and the summed token usage.
 * @throws Whatever `embedRequest` throws for a failing chunk; no partial result is returned.
 */
async embedBatch(texts: string[]): Promise<EmbeddingBatchResult> {
  const embeddings: number[][] = [];
  let totalTokensUsed = 0;

  for (const batch of this.splitIntoRequestBatches(texts)) {
    const result = await this.embedRequest(batch);
    // Append one vector at a time: `push(...result.embeddings)` would pass every
    // vector as a separate call argument and can throw a RangeError when a very
    // large batch is sent unsplit (no maxBatchSize configured).
    for (const embedding of result.embeddings) {
      embeddings.push(embedding);
    }
    totalTokensUsed += result.totalTokensUsed;
  }

  return { embeddings, totalTokensUsed };
}

/**
 * Returns the model metadata this provider was constructed with.
 *
 * @returns The stored `CustomModelInfo` instance (not a copy).
 */
getModelInfo(): CustomModelInfo {
  const { modelInfo } = this;
  return modelInfo;
}
Expand Down
39 changes: 39 additions & 0 deletions tests/config.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,32 @@ describe("config schema", () => {
expect(config.customProvider!.requestIntervalMs).toBe(0);
});

it("should parse custom provider with maxBatchSize", () => {
  // camelCase key: the value must be carried through to the parsed config unchanged.
  const rawConfig = {
    embeddingProvider: "custom",
    customProvider: {
      baseUrl: "http://localhost:11434/v1",
      model: "test",
      dimensions: 768,
      maxBatchSize: 64,
    },
  };

  const { customProvider } = parseConfig(rawConfig);

  expect(customProvider!.maxBatchSize).toBe(64);
});

it("should parse custom provider with max_batch_size alias", () => {
  // snake_case alias: the parsed config must expose it under `maxBatchSize`.
  const rawConfig = {
    embeddingProvider: "custom",
    customProvider: {
      baseUrl: "http://localhost:11434/v1",
      model: "test",
      dimensions: 768,
      max_batch_size: 32,
    },
  };

  const { customProvider } = parseConfig(rawConfig);

  expect(customProvider!.maxBatchSize).toBe(32);
});

it("should clamp concurrency to minimum of 1", () => {
const config = parseConfig({
embeddingProvider: "custom",
Expand All @@ -636,6 +662,19 @@ describe("config schema", () => {
expect(config.customProvider!.concurrency).toBe(1);
});

it("should clamp maxBatchSize to minimum of 1", () => {
  // A batch size of 0 is below the allowed minimum and must be raised to 1.
  const rawConfig = {
    embeddingProvider: "custom",
    customProvider: {
      baseUrl: "http://localhost:11434/v1",
      model: "test",
      dimensions: 768,
      maxBatchSize: 0,
    },
  };

  const { customProvider } = parseConfig(rawConfig);

  expect(customProvider!.maxBatchSize).toBe(1);
});

it("should leave concurrency undefined when not provided", () => {
const config = parseConfig({
embeddingProvider: "custom",
Expand Down
85 changes: 75 additions & 10 deletions tests/custom-provider.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
import { createEmbeddingProvider, CustomProviderNonRetryableError } from "../src/embeddings/provider.js";
import { createCustomProviderInfo } from "../src/embeddings/detector.js";
import { createCustomProviderInfo, type ConfiguredProviderInfo } from "../src/embeddings/detector.js";
import { Indexer } from "../src/indexer/index.js";
import { parseConfig } from "../src/config/schema.js";
import pRetry from "p-retry";
Expand All @@ -11,6 +11,30 @@ import * as path from "path";
describe("CustomEmbeddingProvider", () => {
let fetchSpy: ReturnType<typeof vi.spyOn>;

/**
 * Asserts that `info` describes the custom provider and returns it narrowed
 * to the custom variant, so custom-only fields can be read without casts.
 *
 * @param info - Provider info of any configured provider variant.
 * @returns The same object, typed as the `custom` variant.
 * @throws Error if `info.provider` is not `"custom"`.
 */
function getCustomProviderInfo(
  info: ConfiguredProviderInfo
): Extract<ConfiguredProviderInfo, { provider: "custom" }> {
  expect(info.provider).toBe("custom");
  if (info.provider === "custom") {
    return info;
  }
  throw new Error("Expected custom provider info");
}

/**
 * Waits for `promise` to reject and resolves with the rejection reason as an `Error`.
 *
 * @param promise - Promise that is expected to reject.
 * @returns The rejection reason if it is an `Error`; otherwise a new `Error`
 *   whose message is the stringified reason.
 * @throws Error("Expected promise to reject") if `promise` fulfills instead.
 */
async function getRejectedError<T>(promise: Promise<T>): Promise<Error> {
  try {
    await promise;
  } catch (reason: unknown) {
    return reason instanceof Error ? reason : new Error(String(reason));
  }
  // Reached only when the awaited promise fulfilled; kept outside the try so
  // this failure is never converted into a returned error.
  throw new Error("Expected promise to reject");
}

beforeEach(() => {
fetchSpy = vi.spyOn(globalThis, "fetch");
});
Expand Down Expand Up @@ -94,6 +118,47 @@ describe("CustomEmbeddingProvider", () => {
expect(result.totalTokensUsed).toBe(30);
});

it("should split custom provider requests by maxBatchSize", async () => {
  // One mocked response per expected request: five texts with a batch size of 2
  // should yield requests of 2, 2 and 1 inputs.
  const mockedResponses = [
    { fills: [0.1, 0.2], tokens: 20 },
    { fills: [0.3, 0.4], tokens: 22 },
    { fills: [0.5], tokens: 11 },
  ];
  for (const { fills, tokens } of mockedResponses) {
    fetchSpy.mockResolvedValueOnce(new Response(JSON.stringify({
      data: fills.map((value) => ({ embedding: new Array(768).fill(value) })),
      usage: { total_tokens: tokens },
    }), { status: 200 }));
  }

  const provider = createEmbeddingProvider(createCustomProviderInfo({
    baseUrl: "http://localhost:11434/v1",
    model: "nomic-embed-text",
    dimensions: 768,
    maxBatchSize: 2,
  }));

  const result = await provider.embedBatch(["text1", "text2", "text3", "text4", "text5"]);

  // Reads the `input` array from the JSON body of the n-th fetch call.
  const sentInput = (callIndex: number) =>
    JSON.parse((fetchSpy.mock.calls[callIndex] as [string, RequestInit])[1].body as string).input;

  expect(fetchSpy).toHaveBeenCalledTimes(3);
  expect(sentInput(0)).toEqual(["text1", "text2"]);
  expect(sentInput(1)).toEqual(["text3", "text4"]);
  expect(sentInput(2)).toEqual(["text5"]);
  expect(result.embeddings).toHaveLength(5);
  expect(result.totalTokensUsed).toBe(53);
});

it("should estimate tokens when usage is not provided", async () => {
fetchSpy.mockResolvedValueOnce(new Response(JSON.stringify({
data: [{ embedding: new Array(768).fill(0) }],
Expand Down Expand Up @@ -213,28 +278,28 @@ describe("CustomEmbeddingProvider", () => {
});

it("should default timeout to 30000ms", () => {
const info = createCustomProviderInfo({
const info = getCustomProviderInfo(createCustomProviderInfo({
baseUrl: "http://localhost:11434/v1",
model: "nomic-embed-text",
dimensions: 768,
});
}));
expect(info.modelInfo.timeoutMs).toBe(30000);
});

it("should use custom timeout value from config", () => {
const info = createCustomProviderInfo({
const info = getCustomProviderInfo(createCustomProviderInfo({
baseUrl: "http://localhost:11434/v1",
model: "nomic-embed-text",
dimensions: 768,
timeoutMs: 60000,
});
}));
expect(info.modelInfo.timeoutMs).toBe(60000);
});

it("should throw non-retryable error on 4xx responses (except 429)", async () => {
fetchSpy.mockResolvedValueOnce(new Response("Unauthorized", { status: 401 }));
const provider = createProvider();
const error = await provider.embedQuery("test").catch((e: Error) => e);
const error = await getRejectedError(provider.embedQuery("test"));
expect(error).toBeInstanceOf(CustomProviderNonRetryableError);
expect(error.message).toContain("non-retryable");
expect(error.message).toContain("401");
Expand All @@ -243,29 +308,29 @@ describe("CustomEmbeddingProvider", () => {
it("should throw non-retryable error on 400 Bad Request", async () => {
fetchSpy.mockResolvedValueOnce(new Response("Bad model name", { status: 400 }));
const provider = createProvider();
const error = await provider.embedQuery("test").catch((e: Error) => e);
const error = await getRejectedError(provider.embedQuery("test"));
expect(error).toBeInstanceOf(CustomProviderNonRetryableError);
});

it("should throw non-retryable error on 403 Forbidden", async () => {
fetchSpy.mockResolvedValueOnce(new Response("Forbidden", { status: 403 }));
const provider = createProvider();
const error = await provider.embedQuery("test").catch((e: Error) => e);
const error = await getRejectedError(provider.embedQuery("test"));
expect(error).toBeInstanceOf(CustomProviderNonRetryableError);
});

it("should throw retryable error on 429 rate limit", async () => {
fetchSpy.mockResolvedValueOnce(new Response("Rate limited", { status: 429 }));
const provider = createProvider();
const error = await provider.embedQuery("test").catch((e: Error) => e);
const error = await getRejectedError(provider.embedQuery("test"));
expect(error).not.toBeInstanceOf(CustomProviderNonRetryableError);
expect(error.message).toContain("429");
});

it("should throw retryable error on 5xx server errors", async () => {
fetchSpy.mockResolvedValueOnce(new Response("Internal Server Error", { status: 500 }));
const provider = createProvider();
const error = await provider.embedQuery("test").catch((e: Error) => e);
const error = await getRejectedError(provider.embedQuery("test"));
expect(error).not.toBeInstanceOf(CustomProviderNonRetryableError);
expect(error.message).toContain("500");
});
Expand Down
10 changes: 10 additions & 0 deletions tests/embeddings.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,5 +67,15 @@ describe("embeddings detector", () => {
});
expect(info.modelInfo.maxTokens).toBe(4096);
});

it("should pass through optional maxBatchSize", () => {
  const info = createCustomProviderInfo({
    baseUrl: "http://localhost/v1",
    model: "test",
    dimensions: 512,
    maxBatchSize: 64,
  });

  // `maxBatchSize` is declared only on the custom model info, so narrow the
  // provider union before reading it; this also asserts the provider kind.
  expect(info.provider).toBe("custom");
  if (info.provider !== "custom") {
    throw new Error("Expected custom provider info");
  }
  expect(info.modelInfo.maxBatchSize).toBe(64);
});
});
});
Loading