@@ -75,6 +75,7 @@ def __init__(
         n_ctx: int = 512,
         n_batch: int = 512,
         n_ubatch: int = 512,
+        n_seq_max: Optional[int] = None,
         n_threads: Optional[int] = None,
         n_threads_batch: Optional[int] = None,
         rope_scaling_type: Optional[
@@ -160,6 +161,9 @@ def __init__(
             n_ctx: Text context, 0 = from model
             n_batch: Prompt processing maximum batch size
             n_ubatch: Physical batch size
+            n_seq_max: Maximum number of sequences. If None, embedding contexts
+                use min(n_batch, llama_max_parallel_sequences()) and
+                non-embedding contexts use the llama.cpp default.
             n_threads: Number of threads to use for generation
             n_threads_batch: Number of threads to use for batch processing
             rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
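The new parameter is passed straight through the `Llama` constructor. Below is a minimal usage sketch for an embedding context; the model path is hypothetical, and any local GGUF embedding model would do:

```python
from llama_cpp import Llama

# Hypothetical local model path; substitute any GGUF embedding model.
llm = Llama(
    model_path="./models/embedding-model.gguf",
    embedding=True,
    n_batch=512,
    n_seq_max=8,  # explicit cap on parallel sequences per decode
)

# Leaving n_seq_max=None on an embedding context instead defaults to
# min(n_batch, llama_max_parallel_sequences()), per the docstring above.
```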
@@ -397,6 +401,21 @@ def __init__(
         self.context_params.n_batch = self.n_batch
         self.context_params.n_ubatch = min(self.n_batch, n_ubatch)

+        if n_seq_max is not None:
+            n_seq_max_limit = llama_cpp.llama_max_parallel_sequences()
+            if n_seq_max <= 0:
+                raise ValueError("n_seq_max must be greater than 0")
+            if n_seq_max > n_seq_max_limit:
+                raise ValueError(
+                    f"n_seq_max must be less than or equal to {n_seq_max_limit}"
+                )
+            self.context_params.n_seq_max = n_seq_max
+        elif embedding:
+            self.context_params.n_seq_max = min(
+                self.n_batch,
+                llama_cpp.llama_max_parallel_sequences(),
+            )
+
         self._ctx = self._stack.enter_context(
             contextlib.closing(
                 internals.LlamaContext(
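The checks above reject anything outside the range 1..llama_max_parallel_sequences(). A short sketch of both failure modes, assuming a valid model file at the same hypothetical path:

```python
import llama_cpp
from llama_cpp import Llama

MODEL = "./models/embedding-model.gguf"  # hypothetical path
limit = llama_cpp.llama_max_parallel_sequences()

try:
    Llama(model_path=MODEL, embedding=True, n_seq_max=0)
except ValueError as err:
    print(err)  # n_seq_max must be greater than 0

try:
    Llama(model_path=MODEL, embedding=True, n_seq_max=limit + 1)
except ValueError as err:
    print(err)  # n_seq_max must be less than or equal to <limit>
```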
@@ -1030,6 +1049,7 @@ def embed(
         """
         n_embd = self.n_embd()
         n_batch = self.n_batch
+        n_seq_max = self.context_params.n_seq_max

         # get pooling information
         pooling_type = self.pooling_type()
@@ -1104,7 +1124,7 @@ def decode_batch(seq_sizes: List[int]):
                 )

             # time to eval batch
-            if t_batch + n_tokens > n_batch:
+            if t_batch + n_tokens > n_batch or p_batch >= n_seq_max:
                 decode_batch(s_batch)
                 s_batch = []
                 t_batch = 0
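With this change, `embed()` flushes the pending batch not only when the next input would exceed the token budget (`n_batch`) but also once `n_seq_max` sequences have accumulated, keeping each decode within the context's sequence limit. A usage sketch, reusing the hypothetical `llm` from above:

```python
texts = [f"document number {i}" for i in range(32)]

# With n_seq_max=8, these 32 inputs are decoded in at least four
# batches, even if all of them fit within n_batch tokens combined.
vectors = llm.embed(texts)
assert len(vectors) == len(texts)
```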
@@ -2099,6 +2119,7 @@ def __getstate__(self):
             n_ctx=self.context_params.n_ctx,
             n_batch=self.n_batch,
             n_ubatch=self.context_params.n_ubatch,
+            n_seq_max=self.context_params.n_seq_max,
             n_threads=self.context_params.n_threads,
             n_threads_batch=self.context_params.n_threads_batch,
             rope_scaling_type=self.context_params.rope_scaling_type,
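Because `__getstate__` now records `n_seq_max` alongside the other constructor kwargs, the setting survives pickling. A sketch of the round trip, again assuming the hypothetical `llm` above:

```python
import pickle

payload = pickle.dumps(llm)       # captures constructor kwargs, incl. n_seq_max
restored = pickle.loads(payload)  # rebuilds the Llama with the same settings
assert restored.context_params.n_seq_max == llm.context_params.n_seq_max
```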