Skip to content

Commit 361e36f

Browse files
authored
Merge pull request #24 from verisoft-ai/feat/extra_llm_support_vision
Feat/extra llm support vision
2 parents 87f78b1 + 741ff08 commit 361e36f

File tree

5 files changed

+376
-19
lines changed

5 files changed

+376
-19
lines changed

lib/commands/vision.ts

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ import {
77
buildVisionPrompt,
88
callVisionLLM,
99
computeCoordMapping,
10+
getApiKeyEnvVar,
11+
getProviderForModel,
1012
parseVisionCoords,
1113
} from '../vision-utils';
1214

@@ -44,17 +46,18 @@ export async function executeFindByVision(
4446
this: AppiumDesktopDriver,
4547
args: { prompt: string; model?: string },
4648
): Promise<{ x: number; y: number; label: string }> {
47-
const apiKey = process.env.ANTHROPIC_API_KEY;
49+
const model = args.model ?? 'claude-opus-4-6';
50+
const envVar = getApiKeyEnvVar(getProviderForModel(model));
51+
const apiKey = process.env[envVar];
4852
if (!apiKey) {
4953
throw new Error(
50-
'ANTHROPIC_API_KEY environment variable is required for windows: findByVision'
54+
`${envVar} environment variable is required for windows: findByVision (model: ${model})`
5155
);
5256
}
5357

5458
const base64 = await this.getScreenshot();
5559
const { width: ssW, height: ssH } = getPngDimensions(base64);
5660

57-
const model = args.model ?? 'claude-opus-4-6';
5861
const raw = await callVisionLLM(base64, buildVisionPrompt(args.prompt, ssW, ssH), model, apiKey);
5962
const parsed = parseVisionCoords(raw, args.prompt);
6063

lib/mcp/tools/vision.ts

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import {
1010
buildVisionPrompt,
1111
callVisionLLM,
1212
computeCoordMapping,
13+
getApiKeyEnvVar,
14+
getProviderForModel,
1315
parseVisionCoords,
1416
} from '../../vision-utils';
1517

@@ -52,7 +54,8 @@ export function registerVisionTools(server: McpServer, session: AppiumSession):
5254
'For "coordinates" format, locates a UI element and returns {x,y,label} with actual screen ' +
5355
'coordinates (DPI-corrected) ready to pass to click tools. ' +
5456
'For "text" format, answers a general question about the screen in plain text. ' +
55-
'Requires ANTHROPIC_API_KEY environment variable.',
57+
'Requires ANTHROPIC_API_KEY (Claude), OPENAI_API_KEY (GPT-4o / o-series), or ' +
58+
'GEMINI_API_KEY (Gemini) depending on the chosen model.',
5659
inputSchema: {
5760
prompt: z.string().min(1).describe('Question or instruction about the screenshot'),
5861
responseFormat: z.enum(['coordinates', 'text']).default('coordinates').describe(
@@ -65,15 +68,18 @@ export function registerVisionTools(server: McpServer, session: AppiumSession):
6568
},
6669
async ({ prompt, responseFormat, model }) => {
6770
try {
68-
const apiKey = process.env.ANTHROPIC_API_KEY;
71+
const visionModel = model ?? DEFAULT_MODEL;
72+
const envVar = getApiKeyEnvVar(getProviderForModel(visionModel));
73+
const apiKey = process.env[envVar];
6974
if (!apiKey) {
70-
throw new Error('ANTHROPIC_API_KEY environment variable is required for find_by_vision');
75+
throw new Error(
76+
`${envVar} environment variable is required for find_by_vision (model: ${visionModel})`
77+
);
7178
}
7279

7380
const driver = session.getDriver();
7481
const base64 = await driver.takeScreenshot() as string;
7582
const { width: ssW, height: ssH } = getPngDimensions(base64);
76-
const visionModel = model ?? DEFAULT_MODEL;
7783

7884
if (responseFormat === 'text') {
7985
const textPrompt = `Answer the following about this screenshot: "${prompt}"\nRespond with plain text.`;

lib/vision-utils.ts

Lines changed: 143 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,41 @@
11
import Anthropic from '@anthropic-ai/sdk';
22

3+
/** Vendor whose API serves a given vision model. */
export type LLMProvider = 'anthropic' | 'openai' | 'google';

/** Human-readable model-name patterns; only used to build the "Unsupported model" error text. */
const SUPPORTED_MODELS: readonly string[] = [
  'claude-* (e.g. claude-sonnet-4-6)',
  'gpt-* (e.g. gpt-4o)',
  'o1, o3, o4, o1-mini, o3-pro, …',
  'gemini-* (e.g. gemini-1.5-pro)',
];

/**
 * Infers the LLM provider from the model identifier.
 *
 * Matching is case-insensitive and purely prefix-based:
 * `gpt-` or `o<digit>` → `'openai'`, `gemini-` → `'google'`, `claude-` → `'anthropic'`.
 *
 * @param model - Model identifier as supplied by the caller (e.g. `gpt-4o`).
 * @returns The provider whose API should be called for this model.
 * @throws Error when the identifier matches none of the known prefixes; the
 *   message lists the supported patterns.
 */
export function getProviderForModel(model: string): LLMProvider {
  const lower = model.toLowerCase();
  // `o` followed by a digit covers the o-series names (o1, o3-pro, o4-mini, ...).
  if (lower.startsWith('gpt-') || /^o\d/.test(lower)) {
    return 'openai';
  }
  if (lower.startsWith('gemini-')) {
    return 'google';
  }
  if (lower.startsWith('claude-')) {
    return 'anthropic';
  }
  throw new Error(
    `Unsupported model: "${model}". ` +
      `Supported model prefixes are:\n  ${SUPPORTED_MODELS.join('\n  ')}`,
  );
}
29+
30+
/**
 * Returns the environment variable name that holds the API key for the given provider.
 *
 * @param provider - Provider to look up.
 * @returns `OPENAI_API_KEY`, `GEMINI_API_KEY`, or `ANTHROPIC_API_KEY`.
 * @throws Error if `provider` is not one of the known providers. This is only
 *   reachable when the type system is bypassed; failing loudly avoids silently
 *   reading the Anthropic key for some other provider.
 */
export function getApiKeyEnvVar(provider: LLMProvider): string {
  switch (provider) {
    case 'openai':
      return 'OPENAI_API_KEY';
    case 'google':
      return 'GEMINI_API_KEY';
    case 'anthropic':
      return 'ANTHROPIC_API_KEY';
    default:
      // Every known provider is handled above; anything else is a caller bug.
      throw new Error(`Unsupported LLM provider: "${String(provider)}"`);
  }
}
38+
339
export interface CoordMapping {
440
offsetX: number;
541
offsetY: number;
@@ -89,17 +125,12 @@ export function parseVisionCoords(
89125
return parsed;
90126
}
91127

92-
/**
93-
* Sends a base64 screenshot + text prompt to a Claude vision model and returns
94-
* the raw text response. The caller is responsible for building the prompt and
95-
* parsing the result.
96-
*/
97-
export async function callVisionLLM(
128+
async function callAnthropicVision(
98129
base64: string,
99130
textPrompt: string,
100131
model: string,
101132
apiKey: string,
102-
maxTokens = 256,
133+
maxTokens: number,
103134
): Promise<string> {
104135
const client = new Anthropic({ apiKey });
105136
const response = await client.messages.create({
@@ -118,3 +149,108 @@ export async function callVisionLLM(
118149
});
119150
return response.content.find((b) => b.type === 'text')?.text ?? '';
120151
}
152+
153+
/**
 * Sends a base64 PNG plus a text prompt to the OpenAI Chat Completions endpoint
 * and returns the text of the first choice.
 *
 * @param base64 - PNG image data, base64-encoded (embedded as a `data:` URL).
 * @param textPrompt - Instruction sent alongside the image.
 * @param model - OpenAI model identifier, passed through unchanged.
 * @param apiKey - Value used for the `Authorization: Bearer` header.
 * @param maxTokens - Upper bound on generated tokens.
 * @returns The string found at `choices[0].message.content`.
 * @throws Error `OpenAI API error: …` when the HTTP response is not OK, or an
 *   "Unexpected response" error when the first choice carries no string content.
 */
async function callOpenAIVision(
  base64: string,
  textPrompt: string,
  model: string,
  apiKey: string,
  maxTokens: number,
): Promise<string> {
  const res = await fetch('https://api.openai.com/v1/chat/completions', {
    method: 'POST',
    headers: {
      'Content-Type': 'application/json',
      'Authorization': `Bearer ${apiKey}`,
    },
    body: JSON.stringify({
      model,
      // `max_tokens` is deprecated by OpenAI and rejected by o-series models
      // (o1, o3, ...); `max_completion_tokens` is the supported replacement.
      max_completion_tokens: maxTokens,
      messages: [{
        role: 'user',
        content: [
          { type: 'image_url', image_url: { url: `data:image/png;base64,${base64}` } },
          { type: 'text', text: textPrompt },
        ],
      }],
    }),
  });
  if (!res.ok) {
    const body = await res.text();
    let message: string;
    try {
      // Prefer the structured error message; fall back to the raw body.
      message = (JSON.parse(body) as { error?: { message: string } }).error?.message ?? body;
    } catch {
      // Body was not JSON (or was empty): use whatever text is available.
      message = body || res.statusText;
    }
    throw new Error(`OpenAI API error: ${message}`);
  }
  const data = await res.json() as { choices?: Array<{ message: { content: string } }> };
  const content = data.choices?.[0]?.message?.content;
  if (typeof content !== 'string') {
    throw new Error(`Unexpected response from OpenAI model "${model}": no text content in choices[0].message.content`);
  }
  return content;
}
195+
196+
/**
 * Sends a base64 PNG plus a text prompt to the Gemini `generateContent`
 * endpoint and returns the text of the first candidate.
 *
 * @param base64 - PNG image data, base64-encoded (sent as `inline_data`).
 * @param textPrompt - Instruction sent alongside the image.
 * @param model - Gemini model identifier; URL-encoded before being placed in the path.
 * @param apiKey - Value sent in the `x-goog-api-key` header.
 * @param maxTokens - Forwarded as `generationConfig.maxOutputTokens`.
 * @returns The concatenated text of every non-thought text part of the first candidate.
 * @throws Error `Gemini API error: …` when the HTTP response is not OK, or an
 *   "Unexpected response" error when the first candidate has no text part.
 */
async function callGoogleVision(
  base64: string,
  textPrompt: string,
  model: string,
  apiKey: string,
  maxTokens: number,
): Promise<string> {
  // Encode the caller-supplied id so characters such as `/` or `?` cannot alter the request path.
  const url = `https://generativelanguage.googleapis.com/v1beta/models/${encodeURIComponent(model)}:generateContent`;
  const res = await fetch(url, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json', 'x-goog-api-key': apiKey },
    body: JSON.stringify({
      contents: [{
        parts: [
          { inline_data: { mime_type: 'image/png', data: base64 } },
          { text: textPrompt },
        ],
      }],
      generationConfig: { maxOutputTokens: maxTokens },
    }),
  });
  if (!res.ok) {
    const body = await res.text();
    let message: string;
    try {
      // Prefer the structured error message; fall back to the raw body.
      message = (JSON.parse(body) as { error?: { message: string } }).error?.message ?? body;
    } catch {
      // Body was not JSON (or was empty): use whatever text is available.
      message = body || res.statusText;
    }
    throw new Error(`Gemini API error: ${message}`);
  }
  const data = await res.json() as {
    candidates?: Array<{ content?: { parts?: Array<{ text?: string; thought?: boolean } | null> } }>;
  };
  const parts = data.candidates?.[0]?.content?.parts;
  // A reply may be split across several parts; gather all answer text instead
  // of trusting that parts[0] holds the whole response.
  const texts: string[] = [];
  if (Array.isArray(parts)) {
    for (const part of parts) {
      if (part && part.thought !== true && typeof part.text === 'string') {
        texts.push(part.text);
      }
    }
  }
  if (texts.length === 0) {
    throw new Error(`Unexpected response from Gemini model "${model}": no text parts in candidates[0].content.parts`);
  }
  return texts.join('');
}
236+
237+
/**
 * Submits a base64 screenshot together with a text prompt to a vision model
 * and resolves with the model's raw text reply.
 *
 * The backend is chosen from the model name via `getProviderForModel`:
 * OpenAI and Google Gemini have dedicated paths, everything else goes to
 * Anthropic. Building the prompt and interpreting the reply are left to the caller.
 *
 * @param base64 - Screenshot image data, base64-encoded.
 * @param textPrompt - Prompt text sent with the image.
 * @param model - Model identifier; also determines the provider.
 * @param apiKey - API key for the provider that serves `model`.
 * @param maxTokens - Output token limit (defaults to 256).
 * @returns The provider's raw text response.
 */
export async function callVisionLLM(
  base64: string,
  textPrompt: string,
  model: string,
  apiKey: string,
  maxTokens = 256,
): Promise<string> {
  const backend = getProviderForModel(model);
  if (backend === 'openai') {
    return callOpenAIVision(base64, textPrompt, model, apiKey, maxTokens);
  }
  if (backend === 'google') {
    return callGoogleVision(base64, textPrompt, model, apiKey, maxTokens);
  }
  // Any remaining provider value is served by the Anthropic client.
  return callAnthropicVision(base64, textPrompt, model, apiKey, maxTokens);
}

test/commands/vision.test.ts

Lines changed: 109 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
*/
44
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
55

6-
const { mockCreate } = vi.hoisted(() => ({
6+
const { mockCreate, mockFetch } = vi.hoisted(() => ({
77
mockCreate: vi.fn(),
8+
mockFetch: vi.fn(),
89
}));
910

1011
vi.mock('@anthropic-ai/sdk', () => ({
@@ -13,6 +14,8 @@ vi.mock('@anthropic-ai/sdk', () => ({
1314
})),
1415
}));
1516

17+
vi.stubGlobal('fetch', mockFetch);
18+
1619
vi.mock('../../lib/winapi/user32', () => ({
1720
getResolutionScalingFactor: vi.fn().mockReturnValue(1.0),
1821
}));
@@ -55,7 +58,7 @@ describe('executeFindByVision', () => {
5558
process.env = { ...savedEnv };
5659
});
5760

58-
it('throws when ANTHROPIC_API_KEY is not set', async () => {
61+
it('throws when ANTHROPIC_API_KEY is not set for default model', async () => {
5962
delete process.env.ANTHROPIC_API_KEY;
6063
const driver = makeMockDriver();
6164

@@ -64,6 +67,24 @@ describe('executeFindByVision', () => {
6467
).rejects.toThrow('ANTHROPIC_API_KEY');
6568
});
6669

70+
// A gpt-* model must demand the OpenAI key when it is absent from the environment.
it('throws when OPENAI_API_KEY is not set for GPT model', async () => {
  delete process.env.OPENAI_API_KEY;
  const driver = makeMockDriver();

  const pending = executeFindByVision.call(driver as any, {
    prompt: 'OK button',
    model: 'gpt-4o',
  });

  await expect(pending).rejects.toThrow('OPENAI_API_KEY');
});

// A gemini-* model must demand the Gemini key when it is absent from the environment.
it('throws when GEMINI_API_KEY is not set for Gemini model', async () => {
  delete process.env.GEMINI_API_KEY;
  const driver = makeMockDriver();

  const pending = executeFindByVision.call(driver as any, {
    prompt: 'OK button',
    model: 'gemini-2.0-flash',
  });

  await expect(pending).rejects.toThrow('GEMINI_API_KEY');
});
87+
6788
it('returns screen coordinates for app session at 100% DPI', async () => {
6889
// At 100% DPI: rect.width === ssW, dpiScale = 1.0, isLogical = false
6990
// scaleX = ssW / rect.width = 1920 / 1920 = 1.0
@@ -207,4 +228,90 @@ describe('executeFindByVision', () => {
207228
expect(result.x).toBe(Math.round(0 + 100 * (2560 / 1920)));
208229
expect(result.y).toBe(Math.round(0 + 100 * (1440 / 1080)));
209230
});
231+
232+
// Exercises the gpt-* path of executeFindByVision against a stubbed global fetch.
describe('OpenAI provider', () => {
  beforeEach(() => {
    process.env.OPENAI_API_KEY = 'openai-test-key';
  });

  it('calls OpenAI API for gpt-4o model', async () => {
    const driver = makeMockDriver();
    // Successful chat-completions reply whose content is the JSON the parser expects.
    const reply = JSON.stringify({ x: 300, y: 400, label: 'button' });
    mockFetch.mockResolvedValue({
      ok: true,
      json: () => Promise.resolve({ choices: [{ message: { content: reply } }] }),
    });

    const result = await executeFindByVision.call(driver as any, {
      prompt: 'button',
      model: 'gpt-4o',
    });

    const expectedInit = expect.objectContaining({
      method: 'POST',
      headers: expect.objectContaining({ 'Authorization': 'Bearer openai-test-key' }),
    });
    expect(mockFetch).toHaveBeenCalledWith(
      'https://api.openai.com/v1/chat/completions',
      expectedInit,
    );
    expect(result.label).toBe('button');
  });

  it('throws on OpenAI API error response', async () => {
    const driver = makeMockDriver();
    // Non-OK response carrying a structured error body.
    const errorBody = JSON.stringify({ error: { message: 'Invalid API key' } });
    mockFetch.mockResolvedValue({
      ok: false,
      statusText: 'Unauthorized',
      text: () => Promise.resolve(errorBody),
    });

    const pending = executeFindByVision.call(driver as any, {
      prompt: 'button',
      model: 'gpt-4o-mini',
    });

    await expect(pending).rejects.toThrow('OpenAI API error: Invalid API key');
  });
});
274+
275+
// Exercises the gemini-* path of executeFindByVision against a stubbed global fetch.
describe('Google Gemini provider', () => {
  beforeEach(() => {
    process.env.GEMINI_API_KEY = 'gemini-test-key';
  });

  it('calls Gemini API for gemini- model', async () => {
    const driver = makeMockDriver();
    // Successful generateContent reply with a single text part holding the coordinates JSON.
    const reply = JSON.stringify({ x: 200, y: 300, label: 'icon' });
    mockFetch.mockResolvedValue({
      ok: true,
      json: () => Promise.resolve({
        candidates: [{ content: { parts: [{ text: reply }] } }],
      }),
    });

    const result = await executeFindByVision.call(driver as any, {
      prompt: 'icon',
      model: 'gemini-2.0-flash',
    });

    const expectedInit = expect.objectContaining({
      method: 'POST',
      headers: expect.objectContaining({ 'x-goog-api-key': 'gemini-test-key' }),
    });
    expect(mockFetch).toHaveBeenCalledWith(
      expect.stringContaining('generativelanguage.googleapis.com'),
      expectedInit,
    );
    expect(result.label).toBe('icon');
  });

  it('throws on Gemini API error response', async () => {
    const driver = makeMockDriver();
    // Non-OK response carrying a structured error body.
    const errorBody = JSON.stringify({ error: { message: 'API key not valid' } });
    mockFetch.mockResolvedValue({
      ok: false,
      statusText: 'Bad Request',
      text: () => Promise.resolve(errorBody),
    });

    const pending = executeFindByVision.call(driver as any, {
      prompt: 'icon',
      model: 'gemini-1.5-pro',
    });

    await expect(pending).rejects.toThrow('Gemini API error: API key not valid');
  });
});
210317
});

0 commit comments

Comments
 (0)