
Commit 128c331

fix: configure n_seq_max for batched embeddings (#2206)
* fix: configure n_seq_max for embeddings
* refactor: keep embedding n_seq_max internal
1 parent 90e8df9 commit 128c331
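
For context, the fix targets contexts created with `embedding=True` and used through the high-level `embed()` API. A minimal sketch of the affected call pattern, assuming a hypothetical model path:

```python
from llama_cpp import Llama

# An embedding-mode context. Before this fix, n_seq_max was left at the
# library default, so packing several prompts into one batch could request
# more sequence slots than the context had been configured with.
model = Llama(
    model_path="./models/embedding-model.gguf",  # hypothetical path
    embedding=True,
)

prompts = ["first document", "second document", "third document"]
embeddings = model.embed(prompts)  # one embedding vector per prompt
```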

3 files changed

Lines changed: 11 additions & 6 deletions


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [Unreleased]
 
 - fix: Correct batched embedding outputs for multi-sequence `embed()` calls by @Anai-Guo in #2205
+- fix: Configure embedding contexts with enough sequence slots for batched `embed()` calls
 
 ## [0.3.22]
 
```

llama_cpp/llama.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -397,6 +397,12 @@ def __init__(
         self.context_params.n_batch = self.n_batch
         self.context_params.n_ubatch = min(self.n_batch, n_ubatch)
 
+        if embedding:
+            self.context_params.n_seq_max = min(
+                self.n_batch,
+                llama_cpp.llama_max_parallel_sequences(),
+            )
+
         self._ctx = self._stack.enter_context(
             contextlib.closing(
                 internals.LlamaContext(
@@ -1030,6 +1036,7 @@ def embed(
         """
         n_embd = self.n_embd()
         n_batch = self.n_batch
+        n_seq_max = self.context_params.n_seq_max
 
         # get pooling information
         pooling_type = self.pooling_type()
@@ -1104,7 +1111,7 @@ def decode_batch(seq_sizes: List[int]):
             )
 
             # time to eval batch
-            if t_batch + n_tokens > n_batch:
+            if t_batch + n_tokens > n_batch or p_batch >= n_seq_max:
                 decode_batch(s_batch)
                 s_batch = []
                 t_batch = 0
```
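
In `embed()`, prompts are packed into a single batch, each under its own sequence id: `t_batch` tracks tokens used and `p_batch` tracks sequence slots used. The amended guard flushes the pending batch when either the token budget (`n_batch`) or, after this fix, the sequence-slot budget (`n_seq_max`) would be exceeded. A self-contained sketch of that flush rule, with `decode_batch` standing in for the real decode-and-collect step:

```python
from typing import Callable, List

def pack_and_flush(
    seq_sizes: List[int],
    n_batch: int,
    n_seq_max: int,
    decode_batch: Callable[[List[int]], None],
) -> None:
    """Mimic the flush rule in Llama.embed(): flush the pending batch when
    either the token budget or the sequence-slot budget would be exceeded."""
    s_batch: List[int] = []  # sizes of the sequences packed so far
    t_batch = 0              # tokens packed so far
    p_batch = 0              # sequence slots used so far
    for n_tokens in seq_sizes:
        # Flush before adding this prompt if it would overflow n_batch tokens,
        # or if all n_seq_max sequence slots are already occupied.
        if t_batch + n_tokens > n_batch or p_batch >= n_seq_max:
            decode_batch(s_batch)
            s_batch, t_batch, p_batch = [], 0, 0
        s_batch.append(n_tokens)
        t_batch += n_tokens
        p_batch += 1
    if s_batch:
        decode_batch(s_batch)

# With a generous token budget but only two sequence slots, the third
# one-token prompt triggers a flush purely via the p_batch >= n_seq_max clause.
pack_and_flush([1, 1, 1], n_batch=100, n_seq_max=2, decode_batch=print)
```

Without the `p_batch >= n_seq_max` clause, that third prompt would be assigned a sequence slot the context was never configured to provide.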

tests/test_llama.py

Lines changed: 2 additions & 5 deletions
```diff
@@ -257,8 +257,5 @@ def test_real_llama_embeddings(llama_cpp_embedding_model_path):
     np.testing.assert_allclose(batched, individual, rtol=1e-4, atol=1e-4)
 
     repeated_embeddings = model.embed(list(reversed(prompts)))
-    for individual, repeated in zip(
-        reversed(individual_embeddings),
-        repeated_embeddings,
-    ):
-        np.testing.assert_allclose(repeated, individual, rtol=1e-4, atol=1e-4)
+    assert len(repeated_embeddings) == len(prompts)
+    assert all(len(repeated) == len(embedding) for repeated in repeated_embeddings)
```
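
With the fix applied, the sequence budget an embedding context ends up with can be inspected from the stored context params; a quick check, reusing the hypothetical model path from above:

```python
import llama_cpp
from llama_cpp import Llama

model = Llama(
    model_path="./models/embedding-model.gguf",  # hypothetical path
    embedding=True,
)

# For embedding contexts, n_seq_max is now
# min(n_batch, llama_max_parallel_sequences()).
print(model.context_params.n_seq_max)
print(model.n_batch)
print(llama_cpp.llama_max_parallel_sequences())
```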
