@@ -137,9 +137,21 @@ class SparsePreprocessCSR
137137 auto bm25_vals = raft::make_device_vector<Type_f, int64_t >(handle, int (coo_a.nnz ));
138138 raft::util::calc_tfidf_bm25<Index_, Type_f>(handle, csr_matrix.view (), bm25_vals.view ());
139139 if (coo_on) {
140- raft::sparse::matrix::encode_bm25<float , int >(handle, coo_a_matrix, result.view ());
140+ raft::execute_with_dry_run_check (
141+ handle,
142+ [&](raft::resources const & h) {
143+ raft::sparse::matrix::encode_bm25<float , int >(h, coo_a_matrix, result.view ());
144+ },
145+ raft::alloc_behavior::DATA_DRIVEN,
146+ sizeof (float ) * coo_a.nnz );
141147 } else {
142- raft::sparse::matrix::encode_bm25<float , int >(handle, csr_matrix, result.view ());
148+ raft::execute_with_dry_run_check (
149+ handle,
150+ [&](raft::resources const & h) {
151+ raft::sparse::matrix::encode_bm25<float , int >(h, csr_matrix, result.view ());
152+ },
153+ raft::alloc_behavior::DATA_DRIVEN,
154+ sizeof (float ) * coo_a.nnz );
143155 }
144156 ASSERT_TRUE (raft::devArrMatch<Type_f>(bm25_vals.data_handle (),
145157 result.data_handle (),
@@ -151,9 +163,21 @@ class SparsePreprocessCSR
151163 raft::util::calc_tfidf_bm25<Index_, Type_f>(
152164 handle, csr_matrix.view (), tfidf_vals.view (), true );
153165 if (coo_on) {
154- raft::sparse::matrix::encode_tfidf<float , int >(handle, coo_a_matrix, result.view ());
166+ raft::execute_with_dry_run_check (
167+ handle,
168+ [&](raft::resources const & h) {
169+ raft::sparse::matrix::encode_tfidf<float , int >(h, coo_a_matrix, result.view ());
170+ },
171+ raft::alloc_behavior::ARGUMENT_DRIVEN,
172+ 1 );
155173 } else {
156- raft::sparse::matrix::encode_tfidf<float , int >(handle, csr_matrix, result.view ());
174+ raft::execute_with_dry_run_check (
175+ handle,
176+ [&](raft::resources const & h) {
177+ raft::sparse::matrix::encode_tfidf<float , int >(h, csr_matrix, result.view ());
178+ },
179+ raft::alloc_behavior::ARGUMENT_DRIVEN,
180+ 1 );
157181 }
158182 ASSERT_TRUE (raft::devArrMatch<Type_f>(tfidf_vals.data_handle (),
159183 result.data_handle (),
0 commit comments