Skip to content

Commit 10f08de

Browse files
[Model] Add ColPali late interaction model for multi-modal retrieval (vllm-project#36818)
Signed-off-by: Nikita Sukharev <kaonael@gmail.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
1 parent 5e1a373 commit 10f08de

9 files changed

Lines changed: 634 additions & 0 deletions

File tree

docs/models/supported_models.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,7 @@ The following table lists those that are tested in vLLM.
828828
| ------------ | ------ | ------ | ----------------- | -------------------- | ------------------------- |
829829
| `CLIPModel` | CLIP | T / I | `openai/clip-vit-base-patch32`, `openai/clip-vit-large-patch14`, etc. | | |
830830
| `ColModernVBertForRetrieval` | ColModernVBERT | T / I | `ModernVBERT/colmodernvbert-merged` | | |
831+
| `ColPaliForRetrieval` | ColPali | T / I | `vidore/colpali-v1.3-hf` | | |
831832
| `LlamaNemotronVLModel` | Llama Nemotron Embedding + SigLIP | T + I | `nvidia/llama-nemotron-embed-vl-1b-v2` | | |
832833
| `LlavaNextForConditionalGeneration`<sup>C</sup> | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | ✅︎ |
833834
| `Phi3VForCausalLM`<sup>C</sup> | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | | ✅︎ |
Lines changed: 323 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,323 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""Tests for ColPali late interaction model for multi-modal retrieval.
4+
5+
ColPali is a multi-vector retrieval model based on PaliGemma backbone
6+
(SigLIP + Gemma) with ColBERT-style late interaction scoring (MaxSim).
7+
It produces per-token embeddings for both text and image inputs.
8+
"""
9+
10+
import base64
11+
from io import BytesIO
12+
13+
import pytest
14+
import torch
15+
from PIL import Image
16+
17+
from vllm.entrypoints.chat_utils import (
18+
ChatCompletionContentPartImageParam,
19+
ChatCompletionContentPartTextParam,
20+
)
21+
from vllm.entrypoints.pooling.score.utils import ScoreMultiModalParam
22+
23+
from ....conftest import VllmRunner
24+
25+
# Checkpoints every test in this module is parametrized over.
MODELS = [
    "vidore/colpali-v1.3-hf",
]

# Expected size of the last dimension of each per-token embedding, per model.
EMBED_DIMS = {
    "vidore/colpali-v1.3-hf": 128,
}

# Text inputs shared by the token-embedding and late-interaction tests.
TEXT_QUERIES = [
    "What is the capital of France?",
    "Describe the contents of the document.",
]

TEXT_DOCUMENTS = [
    "The capital of France is Paris.",
    "This document contains important financial data.",
]

# Forwarded to the runner as ``dtype`` and ``gpu_memory_utilization``.
DTYPE = "half"
GPU_MEMORY_UTILIZATION = 0.7
45+
46+
47+
def _make_base64_image(
    width: int = 64, height: int = 64, color: tuple[int, int, int] = (255, 0, 0)
) -> str:
    """Create a small solid-color PNG image and return its base64 data URI.

    Args:
        width: Image width in pixels.
        height: Image height in pixels.
        color: RGB fill color used for the whole image.

    Returns:
        A ``data:image/png;base64,...`` URI holding the encoded PNG bytes.
    """
    buf = BytesIO()
    Image.new("RGB", (width, height), color).save(buf, format="PNG")
    # Base64 output is pure ASCII, so name the codec explicitly.
    b64 = base64.b64encode(buf.getvalue()).decode("ascii")
    return f"data:image/png;base64,{b64}"
56+
57+
58+
def _make_image_mm_param(
    image_uri: str,
    text: str | None = None,
) -> ScoreMultiModalParam:
    """Build a ScoreMultiModalParam containing an image (and optional text).

    Args:
        image_uri: URL placed in the ``image_url`` content part.
        text: Optional text appended as a second content part after the
            image; omitted entirely when ``None``.

    Returns:
        A ``ScoreMultiModalParam`` whose ``content`` holds the image part
        first, followed by the text part when one was given.
    """
    content: list = [
        ChatCompletionContentPartImageParam(
            type="image_url",
            image_url={"url": image_uri},
        ),
    ]
    if text is not None:
        content.append(
            ChatCompletionContentPartTextParam(type="text", text=text),
        )
    return ScoreMultiModalParam(content=content)
74+
75+
76+
def _run_token_embed_test(
    vllm_runner: type[VllmRunner],
    model: str,
    *,
    dtype: str,
) -> None:
    """Verify per-token embedding shape and L2 normalization.

    Args:
        vllm_runner: Runner class, used as a context manager to load ``model``.
        model: Model identifier; must be a key of ``EMBED_DIMS``.
        dtype: Value forwarded to the runner's ``dtype`` argument.
    """
    with vllm_runner(
        model,
        runner="pooling",
        dtype=dtype,
        max_model_len=4096,
        enforce_eager=True,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    ) as vllm_model:
        outputs = vllm_model.token_embed([TEXT_QUERIES[0]])

    assert len(outputs) == 1
    emb = torch.tensor(outputs[0])
    # Token embeddings should be 2D: [num_tokens, embed_dim]
    assert emb.dim() == 2
    assert emb.shape[1] == EMBED_DIMS[model]
    assert emb.shape[0] > 1

    # Verify L2 normalization
    norms = torch.norm(emb, p=2, dim=-1)
    torch.testing.assert_close(
        norms,
        torch.ones_like(norms),
        rtol=1e-2,
        atol=1e-2,
    )
108+
109+
110+
def _run_late_interaction_test(
    vllm_runner: type[VllmRunner],
    model: str,
    *,
    dtype: str,
) -> None:
    """Verify MaxSim scoring matches manual computation.

    The first query and first document are embedded separately, scored with
    ``compute_maxsim_score``, and the result is compared against the score
    returned by ``vllm_model.score`` for the same pair.

    Args:
        vllm_runner: Runner class, used as a context manager to load ``model``.
        model: Model identifier passed to the runner.
        dtype: Value forwarded to the runner's ``dtype`` argument.
    """
    from vllm.entrypoints.pooling.score.utils import compute_maxsim_score

    with vllm_runner(
        model,
        runner="pooling",
        dtype=dtype,
        max_model_len=4096,
        enforce_eager=True,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    ) as vllm_model:
        q_outputs = vllm_model.token_embed([TEXT_QUERIES[0]])
        d_outputs = vllm_model.token_embed([TEXT_DOCUMENTS[0]])

        q_emb = torch.tensor(q_outputs[0])
        d_emb = torch.tensor(d_outputs[0])

        manual_score = compute_maxsim_score(q_emb, d_emb).item()

        vllm_scores = vllm_model.score(TEXT_QUERIES[0], TEXT_DOCUMENTS[0])

    assert len(vllm_scores) == 1
    # Allow 1% relative deviation between the two scoring paths.
    assert vllm_scores[0] == pytest.approx(manual_score, rel=0.01)
139+
140+
141+
def _run_relevance_test(
    vllm_runner: type[VllmRunner],
    model: str,
    *,
    dtype: str,
) -> None:
    """Verify that relevant documents score higher than irrelevant ones.

    Args:
        vllm_runner: Runner class, used as a context manager to load ``model``.
        model: Model identifier passed to the runner.
        dtype: Value forwarded to the runner's ``dtype`` argument.
    """
    query = "What is machine learning?"
    documents = [
        "Machine learning is a subset of artificial intelligence.",
        "The weather forecast shows rain tomorrow.",
        "Deep learning uses neural networks for complex tasks.",
    ]

    with vllm_runner(
        model,
        runner="pooling",
        dtype=dtype,
        max_model_len=4096,
        enforce_eager=True,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    ) as vllm_model:
        scores = vllm_model.score(query, documents)

    assert len(scores) == 3
    # documents[1] (weather) is the off-topic entry both others must beat.
    assert scores[0] > scores[1], "ML doc should score higher than weather doc"
    assert scores[2] > scores[1], "DL doc should score higher than weather doc"
168+
169+
170+
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colpali_token_embed(
    vllm_runner,
    model: str,
    dtype: str,
) -> None:
    """Run the token-embedding shape/normalization check for each model."""
    _run_token_embed_test(vllm_runner, model, dtype=dtype)
178+
179+
180+
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colpali_late_interaction_scoring(
    vllm_runner,
    model: str,
    dtype: str,
) -> None:
    """Run the MaxSim-vs-manual score comparison for each model."""
    _run_late_interaction_test(vllm_runner, model, dtype=dtype)
188+
189+
190+
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colpali_relevance_ordering(
    vllm_runner,
    model: str,
    dtype: str,
) -> None:
    """Run the relevant-vs-irrelevant score ordering check for each model."""
    _run_relevance_test(vllm_runner, model, dtype=dtype)
198+
199+
200+
# ── Multimodal scoring tests ────────────────────────────────
201+
202+
203+
def _run_multimodal_text_query_image_docs_test(
    vllm_runner: type[VllmRunner],
    model: str,
    *,
    dtype: str,
) -> None:
    """Score a text query against image documents via the multimodal path.

    Only the number of results and the type of each score are checked; no
    ordering between the two images is asserted.

    Args:
        vllm_runner: Runner class, used as a context manager to load ``model``.
        model: Model identifier passed to the runner.
        dtype: Value forwarded to the runner's ``dtype`` argument.
    """
    red_image = _make_base64_image(64, 64, color=(255, 0, 0))
    blue_image = _make_base64_image(64, 64, color=(0, 0, 255))

    query = "Describe the red object"
    image_docs = [
        _make_image_mm_param(red_image),
        _make_image_mm_param(blue_image),
    ]

    with vllm_runner(
        model,
        runner="pooling",
        dtype=dtype,
        max_model_len=4096,
        enforce_eager=True,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    ) as vllm_model:
        scores = vllm_model.llm.score(query, image_docs)

    assert len(scores) == 2
    for s in scores:
        assert isinstance(s.outputs.score, float)
232+
233+
234+
def _run_multimodal_mixed_docs_test(
    vllm_runner: type[VllmRunner],
    model: str,
    *,
    dtype: str,
) -> None:
    """Score a text query against a mix of text and image documents.

    Args:
        vllm_runner: Runner class, used as a context manager to load ``model``.
        model: Model identifier passed to the runner.
        dtype: Value forwarded to the runner's ``dtype`` argument.
    """
    red_image = _make_base64_image(64, 64, color=(255, 0, 0))

    query = "What is the capital of France?"
    # Index 0 is plain text, index 1 is an image-only multimodal param.
    documents: list = [
        "The capital of France is Paris.",
        _make_image_mm_param(red_image),
    ]

    with vllm_runner(
        model,
        runner="pooling",
        dtype=dtype,
        max_model_len=4096,
        enforce_eager=True,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    ) as vllm_model:
        scores = vllm_model.llm.score(query, documents)

    assert len(scores) == 2
    for s in scores:
        assert isinstance(s.outputs.score, float)
    # Text document about France should score higher than a random image
    assert scores[0].outputs.score > scores[1].outputs.score
264+
265+
266+
def _run_multimodal_image_query_text_docs_test(
    vllm_runner: type[VllmRunner],
    model: str,
    *,
    dtype: str,
) -> None:
    """Score an image query against text documents.

    The query is a multimodal param carrying a red image plus the text
    ``"red color"``. Only the result count and score types are checked.

    Args:
        vllm_runner: Runner class, used as a context manager to load ``model``.
        model: Model identifier passed to the runner.
        dtype: Value forwarded to the runner's ``dtype`` argument.
    """
    red_image = _make_base64_image(64, 64, color=(255, 0, 0))
    image_query = _make_image_mm_param(red_image, text="red color")

    documents = [
        "A bright red sports car.",
        "The weather forecast shows rain tomorrow.",
    ]

    with vllm_runner(
        model,
        runner="pooling",
        dtype=dtype,
        max_model_len=4096,
        enforce_eager=True,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    ) as vllm_model:
        scores = vllm_model.llm.score(image_query, documents)

    assert len(scores) == 2
    for s in scores:
        assert isinstance(s.outputs.score, float)
294+
295+
296+
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colpali_multimodal_text_query_image_docs(
    vllm_runner,
    model: str,
    dtype: str,
) -> None:
    """Run the text-query / image-documents scoring check for each model."""
    _run_multimodal_text_query_image_docs_test(vllm_runner, model, dtype=dtype)
304+
305+
306+
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colpali_multimodal_mixed_docs(
    vllm_runner,
    model: str,
    dtype: str,
) -> None:
    """Run the mixed text/image documents scoring check for each model."""
    _run_multimodal_mixed_docs_test(vllm_runner, model, dtype=dtype)
314+
315+
316+
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colpali_multimodal_image_query_text_docs(
    vllm_runner,
    model: str,
    dtype: str,
) -> None:
    """Run the image-query / text-documents scoring check for each model."""
    _run_multimodal_image_query_text_docs_test(vllm_runner, model, dtype=dtype)

tests/models/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,7 @@ def check_available_online(
631631
"ColModernVBertForRetrieval": _HfExamplesInfo(
632632
"ModernVBERT/colmodernvbert-merged",
633633
),
634+
"ColPaliForRetrieval": _HfExamplesInfo("vidore/colpali-v1.3-hf"),
634635
"ColQwen3": _HfExamplesInfo(
635636
"TomoroAI/tomoro-colqwen3-embed-4b", trust_remote_code=True
636637
),

0 commit comments

Comments
 (0)