@@ -199,6 +199,8 @@ def _warn_deprecated(symbol: str, hint: str) -> None:
 llama_token_p = ctypes.POINTER(llama_token)
 # typedef int32_t llama_seq_id;
 llama_seq_id = ctypes.c_int32
+# typedef uint32_t llama_state_seq_flags;
+llama_state_seq_flags = ctypes.c_uint32


 # enum llama_vocab_type {
@@ -2835,6 +2837,92 @@ def llama_state_seq_load_file(
 ) -> int: ...


+# for backwards-compat
+# define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1
+LLAMA_STATE_SEQ_FLAGS_SWA_ONLY = 1
+
+# work only with partial states, such as SWA KV cache or recurrent cache
+# (e.g. Mamba)
+# define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1
+LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY = 1
+
+# keeps the tensor data on device buffers
+# (i.e. not accessible in host memory, but faster save/load)
+# define LLAMA_STATE_SEQ_FLAGS_ON_DEVICE 2
+LLAMA_STATE_SEQ_FLAGS_ON_DEVICE = 2
+
+
+# LLAMA_API size_t llama_state_seq_get_size_ext(
+#     struct llama_context * ctx,
+#     llama_seq_id seq_id,
+#     llama_state_seq_flags flags);
+@ctypes_function(
+    "llama_state_seq_get_size_ext",
+    [llama_context_p_ctypes, llama_seq_id, llama_state_seq_flags],
+    ctypes.c_size_t,
+)
+def llama_state_seq_get_size_ext(
+    ctx: llama_context_p,
+    seq_id: llama_seq_id,
+    flags: llama_state_seq_flags,
+    /,
+) -> int: ...
+
+
+# LLAMA_API size_t llama_state_seq_get_data_ext(
+#     struct llama_context * ctx,
+#     uint8_t * dst,
+#     size_t size,
+#     llama_seq_id seq_id,
+#     llama_state_seq_flags flags);
+@ctypes_function(
+    "llama_state_seq_get_data_ext",
+    [
+        llama_context_p_ctypes,
+        ctypes.POINTER(ctypes.c_uint8),
+        ctypes.c_size_t,
+        llama_seq_id,
+        llama_state_seq_flags,
+    ],
+    ctypes.c_size_t,
+)
+def llama_state_seq_get_data_ext(
+    ctx: llama_context_p,
+    dst: CtypesArray[ctypes.c_uint8],
+    size: Union[ctypes.c_size_t, int],
+    seq_id: llama_seq_id,
+    flags: llama_state_seq_flags,
+    /,
+) -> int: ...
+
+
+# LLAMA_API size_t llama_state_seq_set_data_ext(
+#     struct llama_context * ctx,
+#     const uint8_t * src,
+#     size_t size,
+#     llama_seq_id dest_seq_id,
+#     llama_state_seq_flags flags);
+@ctypes_function(
+    "llama_state_seq_set_data_ext",
+    [
+        llama_context_p_ctypes,
+        ctypes.POINTER(ctypes.c_uint8),
+        ctypes.c_size_t,
+        llama_seq_id,
+        llama_state_seq_flags,
+    ],
+    ctypes.c_size_t,
+)
+def llama_state_seq_set_data_ext(
+    ctx: llama_context_p,
+    src: CtypesArray[ctypes.c_uint8],
+    size: Union[ctypes.c_size_t, int],
+    dest_seq_id: llama_seq_id,
+    flags: llama_state_seq_flags,
+    /,
+) -> int: ...
+
+
 # //
 # // Decoding
 # //
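
For reference, a minimal usage sketch of the new bindings (not part of the diff above): it assumes an already-initialized low-level llama_context_p named ctx and that the bindings and flag constants added here are importable as llama_cpp (both assumptions, not shown in this change). It serializes the partial (e.g. SWA) state of one sequence and restores it into another.

import ctypes

import llama_cpp  # assumed import path for the low-level bindings above


def copy_partial_seq_state(ctx, src_seq: int, dst_seq: int) -> int:
    # Hypothetical helper: operate only on the partial state
    # (SWA KV cache / recurrent cache), per the flag defined above.
    flags = llama_cpp.LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY
    # Query how many bytes the partial state of src_seq occupies.
    size = llama_cpp.llama_state_seq_get_size_ext(ctx, src_seq, flags)
    buf = (ctypes.c_uint8 * size)()
    # Serialize the state into the buffer ...
    written = llama_cpp.llama_state_seq_get_data_ext(ctx, buf, size, src_seq, flags)
    # ... and restore it into dst_seq; returns the number of bytes read.
    return llama_cpp.llama_state_seq_set_data_ext(ctx, buf, written, dst_seq, flags)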