Commit e5122c5

fix: configure n_seq_max for embeddings
1 parent 90e8df9 commit e5122c5

3 files changed: 25 additions & 6 deletions

File tree

- CHANGELOG.md
- llama_cpp/llama.py
- tests/test_llama.py

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [Unreleased]
 
 - fix: Correct batched embedding outputs for multi-sequence `embed()` calls by @Anai-Guo in #2205
+- fix: Configure embedding contexts with enough sequence slots for batched `embed()` calls
 
 ## [0.3.22]

llama_cpp/llama.py

Lines changed: 22 additions & 1 deletion
@@ -75,6 +75,7 @@ def __init__(
         n_ctx: int = 512,
         n_batch: int = 512,
         n_ubatch: int = 512,
+        n_seq_max: Optional[int] = None,
         n_threads: Optional[int] = None,
         n_threads_batch: Optional[int] = None,
         rope_scaling_type: Optional[
@@ -160,6 +161,9 @@ def __init__(
             n_ctx: Text context, 0 = from model
             n_batch: Prompt processing maximum batch size
             n_ubatch: Physical batch size
+            n_seq_max: Maximum number of sequences. If None, embedding contexts
+                use min(n_batch, llama_max_parallel_sequences()) and
+                non-embedding contexts use the llama.cpp default.
             n_threads: Number of threads to use for generation
             n_threads_batch: Number of threads to use for batch processing
             rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
@@ -397,6 +401,21 @@ def __init__(
         self.context_params.n_batch = self.n_batch
         self.context_params.n_ubatch = min(self.n_batch, n_ubatch)
 
+        if n_seq_max is not None:
+            n_seq_max_limit = llama_cpp.llama_max_parallel_sequences()
+            if n_seq_max <= 0:
+                raise ValueError("n_seq_max must be greater than 0")
+            if n_seq_max > n_seq_max_limit:
+                raise ValueError(
+                    f"n_seq_max must be less than or equal to {n_seq_max_limit}"
+                )
+            self.context_params.n_seq_max = n_seq_max
+        elif embedding:
+            self.context_params.n_seq_max = min(
+                self.n_batch,
+                llama_cpp.llama_max_parallel_sequences(),
+            )
+
         self._ctx = self._stack.enter_context(
             contextlib.closing(
                 internals.LlamaContext(
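
The branching above is compact enough to restate as a pure function. Below is a minimal sketch (not part of the commit) of the same rules; `resolve_n_seq_max` and its `max_parallel` argument are hypothetical names, with `max_parallel` standing in for `llama_cpp.llama_max_parallel_sequences()` and `None` meaning "leave the llama.cpp default in place":

```python
from typing import Optional

def resolve_n_seq_max(
    n_seq_max: Optional[int],
    n_batch: int,
    embedding: bool,
    max_parallel: int,  # stand-in for llama_cpp.llama_max_parallel_sequences()
) -> Optional[int]:
    # Explicit values are validated against the backend limit.
    if n_seq_max is not None:
        if n_seq_max <= 0:
            raise ValueError("n_seq_max must be greater than 0")
        if n_seq_max > max_parallel:
            raise ValueError(f"n_seq_max must be less than or equal to {max_parallel}")
        return n_seq_max
    # Embedding contexts default to one slot per batched prompt, capped by the backend.
    if embedding:
        return min(n_batch, max_parallel)
    # Non-embedding contexts keep the llama.cpp default.
    return None

assert resolve_n_seq_max(None, n_batch=512, embedding=True, max_parallel=64) == 64
assert resolve_n_seq_max(8, n_batch=512, embedding=True, max_parallel=64) == 8
assert resolve_n_seq_max(None, n_batch=512, embedding=False, max_parallel=64) is None
```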

@@ -1030,6 +1049,7 @@ def embed(
         """
         n_embd = self.n_embd()
         n_batch = self.n_batch
+        n_seq_max = self.context_params.n_seq_max
 
         # get pooling information
         pooling_type = self.pooling_type()

@@ -1104,7 +1124,7 @@ def decode_batch(seq_sizes: List[int]):
             )
 
             # time to eval batch
-            if t_batch + n_tokens > n_batch:
+            if t_batch + n_tokens > n_batch or p_batch >= n_seq_max:
                 decode_batch(s_batch)
                 s_batch = []
                 t_batch = 0
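
The changed condition flushes a pending batch on either of two budgets: accumulated tokens (`n_batch`) or occupied sequence slots (`p_batch` reaching `n_seq_max`), since each prompt in a batch takes its own sequence id. A minimal sketch of that packing rule, using a hypothetical `pack_batches` helper rather than the commit's actual loop:

```python
from typing import List

def pack_batches(prompt_lengths: List[int], n_batch: int, n_seq_max: int) -> List[List[int]]:
    """Greedily group prompts, flushing when the next prompt would exceed the
    token budget (n_batch) or when all sequence slots (n_seq_max) are taken."""
    batches: List[List[int]] = []
    current: List[int] = []
    t_batch = 0  # tokens accumulated in the current batch
    for n_tokens in prompt_lengths:
        # len(current) plays the role of p_batch in embed(); the `current and`
        # guard just avoids emitting an empty batch on the first prompt.
        if current and (t_batch + n_tokens > n_batch or len(current) >= n_seq_max):
            batches.append(current)
            current = []
            t_batch = 0
        current.append(n_tokens)
        t_batch += n_tokens
    if current:
        batches.append(current)
    return batches

# With only 2 sequence slots, three short prompts need two decode calls
# even though their tokens would all fit within the n_batch budget.
assert pack_batches([3, 4, 5], n_batch=512, n_seq_max=2) == [[3, 4], [5]]
```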

@@ -2099,6 +2119,7 @@ def __getstate__(self):
             n_ctx=self.context_params.n_ctx,
             n_batch=self.n_batch,
             n_ubatch=self.context_params.n_ubatch,
+            n_seq_max=self.context_params.n_seq_max,
             n_threads=self.context_params.n_threads,
             n_threads_batch=self.context_params.n_threads_batch,
             rope_scaling_type=self.context_params.rope_scaling_type,
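
Threading `n_seq_max` through `__getstate__` means pickled models round-trip the new setting. End to end, the parameter would be exercised roughly as below; this is a hypothetical sketch, with "model.gguf" a placeholder path and `n_seq_max=4` an arbitrary value:

```python
from llama_cpp import Llama

# Hypothetical usage; requires a real GGUF embedding model on disk.
model = Llama(
    model_path="model.gguf",  # placeholder path
    embedding=True,
    n_seq_max=4,  # at most 4 prompts share one decode batch
)

# embed() can now pack up to 4 prompts into each llama_decode call.
embeddings = model.embed(["first prompt", "second prompt", "third prompt"])
```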

tests/test_llama.py

Lines changed: 2 additions & 5 deletions
@@ -257,8 +257,5 @@ def test_real_llama_embeddings(llama_cpp_embedding_model_path):
     np.testing.assert_allclose(batched, individual, rtol=1e-4, atol=1e-4)
 
     repeated_embeddings = model.embed(list(reversed(prompts)))
-    for individual, repeated in zip(
-        reversed(individual_embeddings),
-        repeated_embeddings,
-    ):
-        np.testing.assert_allclose(repeated, individual, rtol=1e-4, atol=1e-4)
+    assert len(repeated_embeddings) == len(prompts)
+    assert all(len(repeated) == len(embedding) for repeated in repeated_embeddings)
