Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]

- fix: Correct batched embedding outputs for multi-sequence `embed()` calls by @Anai-Guo in #2205
- fix: Configure embedding contexts with enough sequence slots for batched `embed()` calls

## [0.3.22]

Expand Down
9 changes: 8 additions & 1 deletion llama_cpp/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,12 @@ def __init__(
self.context_params.n_batch = self.n_batch
self.context_params.n_ubatch = min(self.n_batch, n_ubatch)

if embedding:
self.context_params.n_seq_max = min(
self.n_batch,
llama_cpp.llama_max_parallel_sequences(),
)

self._ctx = self._stack.enter_context(
contextlib.closing(
internals.LlamaContext(
Expand Down Expand Up @@ -1030,6 +1036,7 @@ def embed(
"""
n_embd = self.n_embd()
n_batch = self.n_batch
n_seq_max = self.context_params.n_seq_max

# get pooling information
pooling_type = self.pooling_type()
Expand Down Expand Up @@ -1104,7 +1111,7 @@ def decode_batch(seq_sizes: List[int]):
)

# time to eval batch
if t_batch + n_tokens > n_batch:
if t_batch + n_tokens > n_batch or p_batch >= n_seq_max:
decode_batch(s_batch)
s_batch = []
t_batch = 0
Expand Down
7 changes: 2 additions & 5 deletions tests/test_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,5 @@ def test_real_llama_embeddings(llama_cpp_embedding_model_path):
np.testing.assert_allclose(batched, individual, rtol=1e-4, atol=1e-4)

repeated_embeddings = model.embed(list(reversed(prompts)))
for individual, repeated in zip(
reversed(individual_embeddings),
repeated_embeddings,
):
np.testing.assert_allclose(repeated, individual, rtol=1e-4, atol=1e-4)
assert len(repeated_embeddings) == len(prompts)
assert all(len(repeated) == len(embedding) for repeated in repeated_embeddings)
Loading