|
1 | 1 | import base64 |
2 | 2 | import logging |
| 3 | +import re |
3 | 4 | from io import BytesIO |
4 | 5 | from mimetypes import guess_type |
5 | 6 |
|
@@ -228,3 +229,211 @@ def generate_inner(self, message, dataset=None): |
228 | 229 | return self.generate_inner_vllm(message, dataset=dataset) |
229 | 230 | else: |
230 | 231 | return self.generate_inner_transformers(message, dataset=dataset) |
| 232 | + |
| 233 | + |
| 234 | +class Gemma4(BaseModel): |
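| | +    """Gemma4 multimodal wrapper supporting both HuggingFace transformers and vLLM inference backends.""" |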
| 235 | + |
| 236 | + INSTALL_REQ = False |
| 237 | + INTERLEAVE = True |
| 238 | + |
| 239 | + def __init__(self, model_path='google/gemma-4-E2B-it', **kwargs): |
| 240 | + self.use_vllm = kwargs.pop('use_vllm', False) |
| 241 | + self.limit_mm_per_prompt = kwargs.pop('limit_mm_per_prompt', 24) |
| 242 | + self.model_path = model_path |
| 243 | + |
| 244 | + try: |
| 245 | + from transformers import AutoProcessor |
| 246 | + |
| 247 | + if not self.use_vllm: |
| 248 | + try: |
| 249 | + from transformers import Gemma4ForConditionalGeneration |
| 250 | + except ImportError: |
| 251 | + try: |
| 252 | + from transformers import \ |
| 253 | + AutoModelForMultimodalLM as Gemma4ForConditionalGeneration |
| 254 | + except ImportError: |
| 255 | + from transformers import \ |
| 256 | + AutoModelForImageTextToText as Gemma4ForConditionalGeneration |
| 257 | + except Exception as e: |
| 258 | +            logging.critical('Please install a recent transformers version with Gemma4 (image-text-to-text) support.') |
| 259 | + raise e |
| 260 | + |
| 261 | + trust_remote_code = kwargs.pop('trust_remote_code', True) |
| 262 | + if self.use_vllm: |
| 263 | + from vllm import LLM |
| 264 | + |
| 265 | +            # Choose tensor_parallel_size from {8, 4, 2, 1} based on the number of available GPUs |
| 266 | + gpu_count = torch.cuda.device_count() |
| 267 | + if gpu_count >= 8: |
| 268 | + tp_size = 8 |
| 269 | + elif gpu_count >= 4: |
| 270 | + tp_size = 4 |
| 271 | + elif gpu_count >= 2: |
| 272 | + tp_size = 2 |
| 273 | + else: |
| 274 | + tp_size = 1 |
| 275 | + logging.info( |
| 276 | + f'Using vLLM for Gemma4 inference with {tp_size} GPUs (available: {gpu_count})' |
| 277 | + ) |
| 278 | + import os |
| 279 | + if os.environ.get('VLLM_WORKER_MULTIPROC_METHOD') != 'spawn': |
| 280 | + logging.warning( |
| 281 | +                    'VLLM_WORKER_MULTIPROC_METHOD is not set to spawn. ' |
| 282 | +                    "Use 'export VLLM_WORKER_MULTIPROC_METHOD=spawn' to avoid potential multi-process issues." |
| 283 | + ) |
| 284 | + self.llm = LLM( |
| 285 | + model=model_path, |
| 286 | + max_num_seqs=4, |
| 287 | + max_model_len=16384, |
| 288 | + limit_mm_per_prompt={"image": self.limit_mm_per_prompt}, |
| 289 | + tensor_parallel_size=tp_size, |
| 290 | + gpu_memory_utilization=kwargs.get("gpu_utils", 0.9), |
| 291 | + trust_remote_code=trust_remote_code, |
| 292 | + ) |
| 293 | + # export VLLM_WORKER_MULTIPROC_METHOD=spawn |
| 294 | + self.device = 'cuda' |
| 295 | + else: |
| 296 | + model_kwargs = { |
| 297 | + 'device_map': kwargs.pop('device_map', 'cuda'), |
| 298 | + 'torch_dtype': kwargs.pop('torch_dtype', torch.bfloat16), |
| 299 | + 'trust_remote_code': trust_remote_code, |
| 300 | + } |
| 301 | + if attn_implementation := kwargs.pop('attn_implementation', None): |
| 302 | + model_kwargs['attn_implementation'] = attn_implementation |
| 303 | + |
| 304 | + self.model = Gemma4ForConditionalGeneration.from_pretrained( |
| 305 | + model_path, **model_kwargs |
| 306 | + ).eval() |
| 307 | + self.device = getattr(self.model, 'device', None) |
| 308 | + if self.device is None: |
| 309 | + self.device = next(self.model.parameters()).device |
| 310 | + |
| 311 | + self.processor = AutoProcessor.from_pretrained( |
| 312 | + model_path, trust_remote_code=trust_remote_code |
| 313 | + ) |
| 314 | + |
| 315 | + default_kwargs = { |
| 316 | + 'do_sample': False, |
| 317 | + 'max_new_tokens': 4096 |
| 318 | + } |
| 319 | + default_kwargs.update(kwargs) |
| 320 | + self.kwargs = default_kwargs |
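| | +        # enable_thinking is consumed by apply_chat_template (prompt construction), not by generate(), so pop it from the generation kwargs |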
| 321 | + self.enable_thinking = self.kwargs.pop('enable_thinking', False) |
| 322 | + |
| 323 | + @staticmethod |
| 324 | + def _load_image(image_path): |
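| | +        # Fully decode the image inside the context manager and return an independent RGB copy so the file handle can be closed |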
| 325 | + with Image.open(image_path) as image: |
| 326 | + return image.convert('RGB').copy() |
| 327 | + |
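| | +    # Convert a VLMEvalKit interleaved message list into the chat-template structure expected by AutoProcessor |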
| 328 | + def message2pipeline(self, message): |
| 329 | + content = [] |
| 330 | + for m in message: |
| 331 | + if m['type'] == 'text': |
| 332 | + content.append(dict(type='text', text=m['value'])) |
| 333 | + elif m['type'] == 'image': |
| 334 | + content.append(dict(type='image', image=self._load_image(m['value']))) |
| 335 | + return [dict(role='user', content=content)] |
| 336 | + |
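| | +    # Build the vLLM chat content (text entries plus image placeholders) and the matching PIL image list, capped at limit_mm_per_prompt |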
| 337 | + def message_to_promptimg_vllm(self, message, dataset=None): |
| 338 | + processed_message = [] |
| 339 | + images = [] |
| 340 | + num_images = 0 |
| 341 | + for item in message: |
| 342 | + if item['type'] == 'text': |
| 343 | + processed_message.append({ |
| 344 | + "type": "text", |
| 345 | + "text": item['value'] |
| 346 | + }) |
| 347 | + elif item['type'] == 'image' and num_images < self.limit_mm_per_prompt: |
| 348 | + processed_message.append({"type": "image"}) |
| 349 | + images.append(self._load_image(item['value'])) |
| 350 | + num_images += 1 |
| 351 | + if num_images >= self.limit_mm_per_prompt: |
| 352 | + logging.warning( |
| 353 | +                f"Number of images exceeds the limit of {self.limit_mm_per_prompt}. " |
| 354 | +                f"Only the first {self.limit_mm_per_prompt} images will be used." |
| 355 | + ) |
| 356 | + return processed_message, images |
| 357 | + |
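| | +    # Strip <think>/<thinking> blocks and channel-style thought segments so only the final answer is returned for evaluation |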
| 358 | + @staticmethod |
| 359 | + def extract_response_for_eval(response): |
| 360 | + if not isinstance(response, str): |
| 361 | + return response |
| 362 | + |
| 363 | + if '</think>' in response: |
| 364 | + response = response.split('</think>')[-1] |
| 365 | + if '</thinking>' in response: |
| 366 | + response = response.split('</thinking>')[-1] |
| 367 | + |
| 368 | + response = re.sub( |
| 369 | + r'<\|channel\>thought\n.*?<channel\|>', |
| 370 | + '', |
| 371 | + response, |
| 372 | + flags=re.DOTALL |
| 373 | + ) |
| 374 | + final_match = re.search( |
| 375 | + r'<\|channel\>final\n(?P<answer>.*?)(?:<channel\|>|$)', |
| 376 | + response, |
| 377 | + flags=re.DOTALL |
| 378 | + ) |
| 379 | + if final_match: |
| 380 | + response = final_match.group('answer') |
| 381 | + |
| 382 | + return response.strip() |
| 383 | + |
| 384 | + def generate_inner_transformers(self, message, dataset=None): |
| 385 | + messages = self.message2pipeline(message) |
| 386 | + inputs = self.processor.apply_chat_template( |
| 387 | + messages, add_generation_prompt=True, tokenize=True, |
| 388 | + return_dict=True, return_tensors='pt', |
| 389 | + enable_thinking=self.enable_thinking, |
| 390 | + ).to(self.device, dtype=torch.bfloat16) |
| 391 | + |
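| | +        # Record the prompt length so only newly generated tokens are decoded below |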
| 392 | + input_len = inputs['input_ids'].shape[-1] |
| 393 | + |
| 394 | + with torch.inference_mode(): |
| 395 | + generation = self.model.generate(**inputs, **self.kwargs) |
| 396 | + generation = generation[0][input_len:] |
| 397 | + |
| 398 | + decoded = self.processor.decode(generation, skip_special_tokens=False) |
| 399 | + if hasattr(self.processor, 'parse_response'): |
| 400 | + decoded = self.processor.parse_response(decoded) |
| 401 | + if isinstance(decoded, dict): |
| 402 | + decoded = decoded.get('answer', decoded.get('response', str(decoded))) |
| 403 | + elif isinstance(decoded, tuple): |
| 404 | + decoded = decoded[-1] |
| 405 | + return self.extract_response_for_eval(decoded) |
| 406 | + |
| 407 | + def generate_inner_vllm(self, message, dataset=None): |
| 408 | + from vllm import SamplingParams |
| 409 | + |
| 410 | +        content, images = self.message_to_promptimg_vllm(message, dataset=dataset) |
| 411 | +        messages = [ |
| 412 | +            {'role': 'user', 'content': content} |
| 413 | +        ] |
| 414 | + prompt = self.processor.apply_chat_template( |
| 415 | + messages, |
| 416 | + tokenize=False, |
| 417 | + add_generation_prompt=True, |
| 418 | + enable_thinking=self.enable_thinking, |
| 419 | + ) |
| 420 | + |
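| | +        # Greedy decoding (temperature=0.0) mirrors do_sample=False in the transformers path |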
| 421 | + sampling_params = SamplingParams(temperature=0.0, |
| 422 | + max_tokens=self.kwargs['max_new_tokens']) |
| 423 | + outputs = self.llm.generate( |
| 424 | + { |
| 425 | + "prompt": prompt, |
| 426 | + "multi_modal_data": { |
| 427 | + "image": images |
| 428 | + }, |
| 429 | + }, |
| 430 | + sampling_params=sampling_params |
| 431 | + ) |
| 432 | +        # Only a single request is submitted, so read the first (and only) output |
| 433 | +        generated_text = outputs[0].outputs[0].text |
| 434 | + return self.extract_response_for_eval(generated_text) |
| 435 | + |
| 436 | + def generate_inner(self, message, dataset=None): |
| 437 | + if self.use_vllm: |
| 438 | + return self.generate_inner_vllm(message, dataset=dataset) |
| 439 | + return self.generate_inner_transformers(message, dataset=dataset) |
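| | + |
| | + |
| | +# Minimal usage sketch (illustrative only, not part of the benchmark harness): builds a |
| | +# VLMEvalKit-style interleaved message and runs the transformers backend. The image path |
| | +# below is a hypothetical placeholder. |
| | +if __name__ == '__main__': |
| | +    model = Gemma4(model_path='google/gemma-4-E2B-it', max_new_tokens=128) |
| | +    demo_message = [ |
| | +        {'type': 'image', 'value': 'example.jpg'},  # hypothetical local image |
| | +        {'type': 'text', 'value': 'Describe this image in one sentence.'}, |
| | +    ] |
| | +    print(model.generate_inner(demo_message)) |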