Skip to content

Commit fdda9bc

Browse files
committed
Add chat template for gemma models
For multimodal Gemma models, the LLaVA-style USER:/ASSISTANT: roles must be replaced with Gemma's user/model roles, wrapped in the <start_of_turn>/<end_of_turn> control tokens.
1 parent 1b1a320 commit fdda9bc

1 file changed

Lines changed: 94 additions & 0 deletions

File tree

llama_cpp/llama_chat_format.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3229,6 +3229,62 @@ def from_pretrained(
32293229
)
32303230

32313231

3232+
class GemmaChatHandler(Llava15ChatHandler):
    """Chat handler for Gemma-based multimodal models (e.g., PaliGemma, MedGemma).

    Gemma models use <start_of_turn>/<end_of_turn> control tokens instead of
    the LLaVA-style USER:/ASSISTANT: format. The text-only 'gemma' chat format
    is already registered (see format_gemma), but multimodal Gemma models that
    require a Llava-style vision pipeline need a dedicated handler so the
    correct chat template is applied when chat_handler takes precedence over
    chat_format in the resolution order.

    See: https://ai.google.dev/gemma/docs/formatting
    """

    DEFAULT_SYSTEM_MESSAGE = None  # Gemma models do not natively support a system role

    # Jinja2 chat template rendered by the Llava15ChatHandler pipeline.
    # Only 'user' and 'model' turns exist in Gemma's format; messages with
    # any other role (and assistant messages with content None) render nothing.
    CHAT_FORMAT = (
        "{% for message in messages %}"
        # System messages are folded into a user turn (Gemma has no system role)
        "{% if message.role == 'system' %}"
        "<start_of_turn>user\n{{ message.content }}<end_of_turn>\n"
        "{% endif %}"
        # User message (handles both plain string and multimodal content list)
        "{% if message.role == 'user' %}"
        "<start_of_turn>user\n"
        "{% if message.content is string %}"
        "{{ message.content }}"
        "{% endif %}"
        # 'is not string' matters: Jinja2 treats strings as iterable, and
        # without it a plain-string content would also enter this branch.
        "{% if message.content is iterable and message.content is not string %}"
        # Emit image tokens first
        "{% for content in message.content %}"
        "{% if content.type == 'image_url' and content.image_url is string %}"
        "{{ content.image_url }}"
        "{% endif %}"
        "{% if content.type == 'image_url' and content.image_url is mapping %}"
        "{{ content.image_url.url }}"
        "{% endif %}"
        "{% endfor %}"
        # Then emit text tokens
        "{% for content in message.content %}"
        "{% if content.type == 'text' %}"
        "{{ content.text }}"
        "{% endif %}"
        "{% endfor %}"
        "{% endif %}"
        "<end_of_turn>\n"
        "{% endif %}"
        # Assistant message
        "{% if message.role == 'assistant' and message.content is not none %}"
        "<start_of_turn>model\n{{ message.content }}<end_of_turn>\n"
        "{% endif %}"
        "{% endfor %}"
        # Generation prompt
        "{% if add_generation_prompt %}"
        "<start_of_turn>model\n"
        "{% endif %}"
    )
32323288
class ObsidianChatHandler(Llava15ChatHandler):
32333289
# Prompt Format
32343290
# The model followed ChatML format. However, with ### as the seperator
@@ -3581,6 +3637,44 @@ def __call__(self, **kwargs):
35813637
return super().__call__(**kwargs)
35823638

35833639

3640+
class MultimodalGemmaChatHandler(Llava15ChatHandler):
    """Chat handler for multimodal Gemma models using Gemma control tokens.

    Formats conversations with <start_of_turn>/<end_of_turn> delimiters and
    the 'user'/'model' role names Gemma expects, emitting image placeholders
    before text so the Llava-style vision pipeline can substitute them.

    See: https://ai.google.dev/gemma/docs/formatting
    """

    # Gemma has no native system role, so no default system prompt is injected.
    DEFAULT_SYSTEM_MESSAGE: Optional[str] = None

    CHAT_FORMAT = (
        "{% for message in messages %}"
        # Fold system messages into a user turn (Gemma defines only 'user'
        # and 'model' roles) instead of silently dropping them.
        "{% if message.role == 'system' %}"
        "<start_of_turn>user\n{{ message.content }}<end_of_turn>\n"
        "{% endif %}"
        # User message: plain string or multimodal content list.
        "{% if message.role == 'user' %}"
        "<start_of_turn>user\n"
        "{% if message.content is string %}"
        "{{ message.content }}"
        "{% endif %}"
        # Guard with 'is not string': Jinja2 treats strings as iterable,
        # which would otherwise iterate over individual characters here.
        "{% if message.content is iterable and message.content is not string %}"
        # Emit image tokens first, then text, so image placeholders
        # precede the prompt text as the vision pipeline expects.
        "{% for content in message.content %}"
        "{% if content.type == 'image_url' and content.image_url is string %}"
        "{{ content.image_url }}"
        "{% endif %}"
        "{% if content.type == 'image_url' and content.image_url is mapping %}"
        "{{ content.image_url.url }}"
        "{% endif %}"
        "{% endfor %}"
        "{% for content in message.content %}"
        "{% if content.type == 'text' %}"
        "{{ content.text }}"
        "{% endif %}"
        "{% endfor %}"
        "{% endif %}"
        "<end_of_turn>\n"
        "{% endif %}"
        # Assistant turns use Gemma's 'model' role name.
        "{% if message.role == 'assistant' and message.content is not none %}"
        "<start_of_turn>model\n"
        "{{ message.content }}<end_of_turn>\n"
        "{% endif %}"
        "{% endfor %}"
        # Open a model turn when the caller wants a completion generated.
        "{% if add_generation_prompt %}"
        "<start_of_turn>model\n"
        "{% endif %}"
    )
3676+
3677+
35843678
@register_chat_completion_handler("chatml-function-calling")
35853679
def chatml_function_calling(
35863680
llama: llama.Llama,

0 commit comments

Comments
 (0)