Skip to content

Commit dee57a3

Browse files
committed
Add Gemma 4 support
1 parent 0916343 commit dee57a3

3 files changed

Lines changed: 216 additions & 2 deletions

File tree

vlmeval/config.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2213,7 +2213,12 @@
22132213

22142214
'Gemma3-4B': partial(vlm.Gemma3, model_path='google/gemma-3-4b-it'),
22152215
'Gemma3-12B': partial(vlm.Gemma3, model_path='google/gemma-3-12b-it'),
2216-
'Gemma3-27B': partial(vlm.Gemma3, model_path='google/gemma-3-27b-it')
2216+
'Gemma3-27B': partial(vlm.Gemma3, model_path='google/gemma-3-27b-it'),
2217+
2218+
'Gemma4-E2B-it': partial(vlm.Gemma4, model_path='google/gemma-4-E2B-it'),
2219+
'Gemma4-E4B-it': partial(vlm.Gemma4, model_path='google/gemma-4-E4B-it'),
2220+
'Gemma4-31B-it': partial(vlm.Gemma4, model_path='google/gemma-4-31B-it'),
2221+
'Gemma4-26B-A4B-it': partial(vlm.Gemma4, model_path='google/gemma-4-26B-A4B-it')
22172222
}
22182223

22192224
aguvis_series = {

vlmeval/vlm/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from .emu import Emu, Emu3_chat, Emu3_gen
2121
from .falcon_vlm import Falcon2VLM
2222
from .flash_vl import FlashVL
23-
from .gemma import Gemma3, PaliGemma
23+
from .gemma import Gemma3, Gemma4, PaliGemma
2424
from .granite_docling import DOCLING
2525
from .granite_vision import GraniteVision3
2626
from .h2ovl_mississippi import H2OVLChat

vlmeval/vlm/gemma.py

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import base64
22
import logging
3+
import re
34
from io import BytesIO
45
from mimetypes import guess_type
56

@@ -228,3 +229,211 @@ def generate_inner(self, message, dataset=None):
228229
return self.generate_inner_vllm(message, dataset=dataset)
229230
else:
230231
return self.generate_inner_transformers(message, dataset=dataset)
232+
233+
234+
class Gemma4(BaseModel):
    """VLMEvalKit wrapper for Google's Gemma 4 multimodal models.

    Two inference backends are supported:

    * HF ``transformers`` (default) -- loads ``Gemma4ForConditionalGeneration``,
      falling back to the generic auto classes on transformers releases that
      do not expose the dedicated class yet.
    * vLLM (``use_vllm=True``) -- tensor-parallel generation across all
      available GPUs.
    """

    INSTALL_REQ = False
    INTERLEAVE = True  # accepts interleaved image/text messages

    def __init__(self, model_path='google/gemma-4-E2B-it', **kwargs):
        """Load the model and processor.

        Args:
            model_path: HF hub id or local path of the Gemma 4 checkpoint.
            **kwargs: backend options (``use_vllm``, ``limit_mm_per_prompt``,
                ``trust_remote_code``, ``device_map``, ``torch_dtype``,
                ``attn_implementation``, ``gpu_utils``, ``enable_thinking``);
                anything left over is forwarded to ``generate()``.
        """
        self.use_vllm = kwargs.pop('use_vllm', False)
        # Maximum number of images forwarded in a single prompt (vLLM path).
        self.limit_mm_per_prompt = kwargs.pop('limit_mm_per_prompt', 24)
        self.model_path = model_path

        try:
            from transformers import AutoProcessor

            if not self.use_vllm:
                # Prefer the dedicated class; fall back to the generic auto
                # classes for transformers versions that lack it.
                try:
                    from transformers import Gemma4ForConditionalGeneration
                except ImportError:
                    try:
                        from transformers import \
                            AutoModelForMultimodalLM as Gemma4ForConditionalGeneration
                    except ImportError:
                        from transformers import \
                            AutoModelForImageTextToText as Gemma4ForConditionalGeneration
        except Exception as e:
            logging.critical('Please install torch and a recent transformers version.')
            raise e

        trust_remote_code = kwargs.pop('trust_remote_code', True)
        if self.use_vllm:
            from vllm import LLM

            # Pick the largest tensor-parallel size in {8, 4, 2, 1} that the
            # available GPU count supports.
            gpu_count = torch.cuda.device_count()
            tp_size = next((s for s in (8, 4, 2) if gpu_count >= s), 1)
            logging.info(
                f'Using vLLM for Gemma4 inference with {tp_size} GPUs (available: {gpu_count})'
            )
            import os
            if os.environ.get('VLLM_WORKER_MULTIPROC_METHOD') != 'spawn':
                # Fixed: adjacent literals previously concatenated without a
                # space ("spawn.Use ...").
                logging.warning(
                    'VLLM_WORKER_MULTIPROC_METHOD is not set to spawn. '
                    "Use 'export VLLM_WORKER_MULTIPROC_METHOD=spawn' "
                    'to avoid potential multi-process issues'
                )
            self.llm = LLM(
                model=model_path,
                max_num_seqs=4,
                max_model_len=16384,
                limit_mm_per_prompt={"image": self.limit_mm_per_prompt},
                tensor_parallel_size=tp_size,
                # pop (not get) so this backend-only knob never leaks into the
                # generate() kwargs below
                gpu_memory_utilization=kwargs.pop("gpu_utils", 0.9),
                trust_remote_code=trust_remote_code,
            )
            self.device = 'cuda'
        else:
            model_kwargs = {
                'device_map': kwargs.pop('device_map', 'cuda'),
                'torch_dtype': kwargs.pop('torch_dtype', torch.bfloat16),
                'trust_remote_code': trust_remote_code,
            }
            if attn_implementation := kwargs.pop('attn_implementation', None):
                model_kwargs['attn_implementation'] = attn_implementation

            self.model = Gemma4ForConditionalGeneration.from_pretrained(
                model_path, **model_kwargs
            ).eval()
            # device_map may shard the model; fall back to the first parameter's
            # device when the model object does not expose one.
            self.device = getattr(self.model, 'device', None)
            if self.device is None:
                self.device = next(self.model.parameters()).device

        self.processor = AutoProcessor.from_pretrained(
            model_path, trust_remote_code=trust_remote_code
        )

        default_kwargs = {
            'do_sample': False,
            'max_new_tokens': 4096,
        }
        default_kwargs.update(kwargs)
        self.kwargs = default_kwargs
        # Thinking mode is handled by the chat template, not by generate().
        self.enable_thinking = self.kwargs.pop('enable_thinking', False)

    @staticmethod
    def _load_image(image_path):
        """Load *image_path* as an RGB copy so the file handle closes eagerly."""
        with Image.open(image_path) as image:
            return image.convert('RGB').copy()

    def message2pipeline(self, message):
        """Convert a VLMEvalKit message list into a one-turn chat conversation.

        Returns a single-element list with a ``user`` role whose content mixes
        ``text`` and pre-loaded ``image`` entries, suitable for
        ``processor.apply_chat_template``.
        """
        content = []
        for m in message:
            if m['type'] == 'text':
                content.append(dict(type='text', text=m['value']))
            elif m['type'] == 'image':
                content.append(dict(type='image', image=self._load_image(m['value'])))
        return [dict(role='user', content=content)]

    def message_to_promptimg_vllm(self, message, dataset=None):
        """Split *message* into (chat content, PIL images) for the vLLM path.

        At most ``limit_mm_per_prompt`` images are kept; a warning is emitted
        only when images were actually dropped.
        """
        total_images = sum(1 for item in message if item['type'] == 'image')
        processed_message = []
        images = []
        num_images = 0
        for item in message:
            if item['type'] == 'text':
                processed_message.append({
                    "type": "text",
                    "text": item['value']
                })
            elif item['type'] == 'image' and num_images < self.limit_mm_per_prompt:
                processed_message.append({"type": "image"})
                images.append(self._load_image(item['value']))
                num_images += 1
        # Fixed: previously warned whenever the count *reached* the limit,
        # even when nothing was truncated, and the two sentences were
        # concatenated without a space.
        if total_images > self.limit_mm_per_prompt:
            logging.warning(
                f"Number of images exceeds the limit of {self.limit_mm_per_prompt}. "
                f"Only the first {self.limit_mm_per_prompt} images will be used."
            )
        return processed_message, images

    @staticmethod
    def extract_response_for_eval(response):
        """Strip thinking/channel markup, returning only the final answer text.

        Handles ``</think>`` / ``</thinking>`` tags and the channel-style
        markup, removing thought sections and extracting the ``final``
        channel when present. Non-string inputs pass through unchanged.
        """
        if not isinstance(response, str):
            return response

        if '</think>' in response:
            response = response.split('</think>')[-1]
        if '</thinking>' in response:
            response = response.split('</thinking>')[-1]

        # Drop any "thought" channel sections entirely.
        response = re.sub(
            r'<\|channel\>thought\n.*?<channel\|>',
            '',
            response,
            flags=re.DOTALL
        )
        # Keep only the "final" channel body when one exists.
        final_match = re.search(
            r'<\|channel\>final\n(?P<answer>.*?)(?:<channel\|>|$)',
            response,
            flags=re.DOTALL
        )
        if final_match:
            response = final_match.group('answer')

        return response.strip()

    def generate_inner_transformers(self, message, dataset=None):
        """Generate a response with the HF transformers backend."""
        messages = self.message2pipeline(message)
        inputs = self.processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True,
            return_dict=True, return_tensors='pt',
            enable_thinking=self.enable_thinking,
        ).to(self.device, dtype=torch.bfloat16)

        input_len = inputs['input_ids'].shape[-1]

        with torch.inference_mode():
            generation = self.model.generate(**inputs, **self.kwargs)
            # Keep only the newly generated tokens.
            generation = generation[0][input_len:]

        decoded = self.processor.decode(generation, skip_special_tokens=False)
        if hasattr(self.processor, 'parse_response'):
            # Newer processors may return a structured answer; normalize to str.
            decoded = self.processor.parse_response(decoded)
            if isinstance(decoded, dict):
                decoded = decoded.get('answer', decoded.get('response', str(decoded)))
            elif isinstance(decoded, tuple):
                decoded = decoded[-1]
        return self.extract_response_for_eval(decoded)

    def generate_inner_vllm(self, message, dataset=None):
        """Generate a response with the vLLM backend (greedy decoding)."""
        from vllm import SamplingParams

        content, images = self.message_to_promptimg_vllm(message, dataset=dataset)
        messages = [
            {'role': 'user', 'content': content}
        ]
        prompt = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=self.enable_thinking,
        )

        sampling_params = SamplingParams(
            temperature=0.0,
            max_tokens=self.kwargs['max_new_tokens'],
        )
        outputs = self.llm.generate(
            {
                "prompt": prompt,
                "multi_modal_data": {
                    "image": images
                },
            },
            sampling_params=sampling_params
        )
        # Single request -> single RequestOutput (previously bound via a loop,
        # which would raise NameError on an empty result).
        generated_text = outputs[0].outputs[0].text
        return self.extract_response_for_eval(generated_text)

    def generate_inner(self, message, dataset=None):
        """Dispatch to the configured backend."""
        if self.use_vllm:
            return self.generate_inner_vllm(message, dataset=dataset)
        return self.generate_inner_transformers(message, dataset=dataset)

0 commit comments

Comments
 (0)