Skip to content

Commit fee4b62

Browse files
committed
Expand test coverage Part 2
1 parent d4ff16e commit fee4b62

File tree

12 files changed

+185
-87
lines changed

12 files changed

+185
-87
lines changed

cpp/tests/sparse/csr_transpose.cu

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2018-2024, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2018-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -94,17 +94,23 @@ class CSRTransposeTest : public ::testing::TestWithParam<CSRTransposeInputs<valu
9494

9595
make_data();
9696

97-
raft::sparse::linalg::csr_transpose(handle,
98-
indptr.data(),
99-
indices.data(),
100-
data.data(),
101-
out_indptr.data(),
102-
out_indices.data(),
103-
out_data.data(),
104-
params.nrows,
105-
params.ncols,
106-
params.nnz,
107-
stream);
97+
raft::execute_with_dry_run_check(
98+
handle,
99+
[&](raft::resources const& h) {
100+
raft::sparse::linalg::csr_transpose(h,
101+
indptr.data(),
102+
indices.data(),
103+
data.data(),
104+
out_indptr.data(),
105+
out_indices.data(),
106+
out_data.data(),
107+
params.nrows,
108+
params.ncols,
109+
params.nnz,
110+
resource::get_cuda_stream(h));
111+
},
112+
raft::alloc_behavior::ARGUMENT_DRIVEN,
113+
1);
108114

109115
resource::sync_stream(handle, stream);
110116
}

cpp/tests/sparse/masked_matmul.cu

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -338,12 +338,18 @@ class MaskedMatmulTest
338338

339339
if constexpr (bits_layout == BitsLayout::Bitmap) {
340340
auto mask = raft::core::bitmap_view<const bits_t, index_t>(bits_d.data(), params.m, params.n);
341-
raft::sparse::linalg::masked_matmul(handle, A, B, mask, C);
341+
raft::execute_with_dry_run_check(
342+
handle,
343+
[&](raft::resources const& h) { raft::sparse::linalg::masked_matmul(h, A, B, mask, C); },
344+
raft::alloc_behavior::ARGUMENT_DRIVEN,
345+
c_data_d.size() * sizeof(output_t));
342346
} else if constexpr (bits_layout == BitsLayout::Bitset) {
343347
auto mask = raft::core::bitset_view<const bits_t, index_t>(bits_d.data(), params.n);
344-
raft::sparse::linalg::masked_matmul(handle, A, B, mask, C);
345-
} else {
346-
GTEST_SKIP() << "Unsupported BitsLayout!";
348+
raft::execute_with_dry_run_check(
349+
handle,
350+
[&](raft::resources const& h) { raft::sparse::linalg::masked_matmul(h, A, B, mask, C); },
351+
raft::alloc_behavior::ARGUMENT_DRIVEN,
352+
c_data_d.size() * sizeof(output_t));
347353
}
348354

349355
resource::sync_stream(handle);

cpp/tests/sparse/norm.cu

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -52,7 +52,12 @@ class CSRRowNormTest : public ::testing::TestWithParam<CSRRowNormInputs<Type_f,
5252
raft::update_device(data.data(), params.data.data(), nnz, stream);
5353
raft::update_device(verify.data(), params.verify.data(), n_rows, stream);
5454

55-
linalg::rowNormCsr(handle, indptr.data(), data.data(), nnz, n_rows, result.data(), params.norm);
55+
raft::execute_with_dry_run_check(
56+
handle,
57+
[&](raft::resources const& h) {
58+
linalg::rowNormCsr(h, indptr.data(), data.data(), nnz, n_rows, result.data(), params.norm);
59+
},
60+
raft::alloc_behavior::NO_ALLOCATIONS);
5661
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
5762

5863
ASSERT_TRUE(

cpp/tests/sparse/preprocess.cu

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,21 @@ class SparsePreprocessCSR
137137
auto bm25_vals = raft::make_device_vector<Type_f, int64_t>(handle, int(coo_a.nnz));
138138
raft::util::calc_tfidf_bm25<Index_, Type_f>(handle, csr_matrix.view(), bm25_vals.view());
139139
if (coo_on) {
140-
raft::sparse::matrix::encode_bm25<float, int>(handle, coo_a_matrix, result.view());
140+
raft::execute_with_dry_run_check(
141+
handle,
142+
[&](raft::resources const& h) {
143+
raft::sparse::matrix::encode_bm25<float, int>(h, coo_a_matrix, result.view());
144+
},
145+
raft::alloc_behavior::DATA_DRIVEN,
146+
sizeof(float) * coo_a.nnz);
141147
} else {
142-
raft::sparse::matrix::encode_bm25<float, int>(handle, csr_matrix, result.view());
148+
raft::execute_with_dry_run_check(
149+
handle,
150+
[&](raft::resources const& h) {
151+
raft::sparse::matrix::encode_bm25<float, int>(h, csr_matrix, result.view());
152+
},
153+
raft::alloc_behavior::DATA_DRIVEN,
154+
sizeof(float) * coo_a.nnz);
143155
}
144156
ASSERT_TRUE(raft::devArrMatch<Type_f>(bm25_vals.data_handle(),
145157
result.data_handle(),
@@ -151,9 +163,21 @@ class SparsePreprocessCSR
151163
raft::util::calc_tfidf_bm25<Index_, Type_f>(
152164
handle, csr_matrix.view(), tfidf_vals.view(), true);
153165
if (coo_on) {
154-
raft::sparse::matrix::encode_tfidf<float, int>(handle, coo_a_matrix, result.view());
166+
raft::execute_with_dry_run_check(
167+
handle,
168+
[&](raft::resources const& h) {
169+
raft::sparse::matrix::encode_tfidf<float, int>(h, coo_a_matrix, result.view());
170+
},
171+
raft::alloc_behavior::ARGUMENT_DRIVEN,
172+
1);
155173
} else {
156-
raft::sparse::matrix::encode_tfidf<float, int>(handle, csr_matrix, result.view());
174+
raft::execute_with_dry_run_check(
175+
handle,
176+
[&](raft::resources const& h) {
177+
raft::sparse::matrix::encode_tfidf<float, int>(h, csr_matrix, result.view());
178+
},
179+
raft::alloc_behavior::ARGUMENT_DRIVEN,
180+
1);
157181
}
158182
ASSERT_TRUE(raft::devArrMatch<Type_f>(tfidf_vals.data_handle(),
159183
result.data_handle(),

cpp/tests/sparse/reduce.cu

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -63,14 +63,21 @@ class SparseReduceTest : public ::testing::TestWithParam<SparseReduceInputs<valu
6363
raft::update_device(out_vals.data(), params.out_vals.data(), params.out_vals.size(), stream);
6464

6565
raft::sparse::COO<value_t, value_idx, value_idx> out(stream);
66-
raft::sparse::op::max_duplicates(handle,
67-
out,
68-
in_rows.data(),
69-
in_cols.data(),
70-
in_vals.data(),
71-
(value_idx)params.in_rows.size(),
72-
(value_idx)params.m,
73-
(value_idx)params.n);
66+
auto min_alloc = (2 * sizeof(value_idx) + sizeof(value_t)) * params.out_vals.size();
67+
raft::execute_with_dry_run_check(
68+
handle,
69+
[&](raft::resources const& h) {
70+
raft::sparse::op::max_duplicates(h,
71+
out,
72+
in_rows.data(),
73+
in_cols.data(),
74+
in_vals.data(),
75+
(value_idx)params.in_rows.size(),
76+
(value_idx)params.m,
77+
(value_idx)params.n);
78+
},
79+
raft::alloc_behavior::DATA_DRIVEN,
80+
min_alloc);
7481
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
7582
ASSERT_TRUE(raft::devArrMatch<value_idx>(
7683
out_rows.data(), out.rows(), out.nnz, raft::Compare<value_idx>()));

cpp/tests/sparse/sddmm.cu

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -313,14 +313,20 @@ class SDDMMTest : public ::testing::TestWithParam<SDDMMInputs<ValueType, IndexTy
313313
auto op_b = params.transpose_b ? raft::linalg::Operation::TRANSPOSE
314314
: raft::linalg::Operation::NON_TRANSPOSE;
315315

316-
raft::sparse::linalg::sddmm(handle,
317-
a,
318-
b,
319-
c,
320-
op_a,
321-
op_b,
322-
raft::make_host_scalar_view<OutputType>(&params.alpha),
323-
raft::make_host_scalar_view<OutputType>(&params.beta));
316+
raft::execute_with_dry_run_check(
317+
handle,
318+
[&](raft::resources const& h) {
319+
raft::sparse::linalg::sddmm(h,
320+
a,
321+
b,
322+
c,
323+
op_a,
324+
op_b,
325+
raft::make_host_scalar_view<OutputType>(&params.alpha),
326+
raft::make_host_scalar_view<OutputType>(&params.beta));
327+
},
328+
raft::alloc_behavior::ARGUMENT_DRIVEN,
329+
1);
324330

325331
resource::sync_stream(handle);
326332

cpp/tests/sparse/select_k_csr.cu

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -287,8 +287,14 @@ class SelectKCsrTest : public ::testing::TestWithParam<SelectKCsrInputs<index_t>
287287
auto out_idx = raft::make_device_matrix_view<index_t, index_t, raft::row_major>(
288288
dst_indices_d.data(), params.n_rows, params.top_k);
289289

290-
raft::sparse::matrix::select_k(
291-
handle, in_val, in_idx, out_val, out_idx, params.select_min, true);
290+
raft::execute_with_dry_run_check(
291+
handle,
292+
[&](raft::resources const& h) {
293+
raft::sparse::matrix::select_k(
294+
h, in_val, in_idx, out_val, out_idx, params.select_min, true);
295+
},
296+
raft::alloc_behavior::ARGUMENT_DRIVEN,
297+
1);
292298

293299
ASSERT_TRUE(raft::devArrMatch<index_t>(dst_indices_expected_d.data(),
294300
out_idx.data_handle(),

cpp/tests/sparse/solver/lanczos.cu

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -192,13 +192,19 @@ class rmat_lanczos_tests
192192
auto csr_matrix = raft::make_device_csr_matrix_view<ValueType, IndexType, IndexType, IndexType>(
193193
const_cast<ValueType*>(symmetric_coo.vals()), csr_structure);
194194

195-
std::get<0>(stats) = raft::sparse::solver::lanczos_compute_eigenpairs<IndexType, ValueType>(
195+
raft::execute_with_dry_run_check(
196196
handle,
197-
config,
198-
csr_matrix,
199-
std::make_optional(v0.view()),
200-
eigenvalues.view(),
201-
eigenvectors.view());
197+
[&](raft::resources const& h) {
198+
std::get<0>(stats) = raft::sparse::solver::lanczos_compute_eigenpairs<IndexType, ValueType>(
199+
h,
200+
config,
201+
csr_matrix,
202+
std::make_optional(v0.view()),
203+
eigenvalues.view(),
204+
eigenvectors.view());
205+
},
206+
raft::alloc_behavior::ARGUMENT_DRIVEN,
207+
sizeof(ValueType) * symmetric_coo.n_rows * config.ncv);
202208

203209
ASSERT_TRUE(raft::devArrMatch<ValueType>(eigenvalues.data_handle(),
204210
expected_eigenvalues.data_handle(),
@@ -340,13 +346,19 @@ class lanczos_tests : public ::testing::TestWithParam<lanczos_inputs<IndexType,
340346
auto csr_matrix = raft::make_device_csr_matrix_view<ValueType, IndexType, IndexType, IndexType>(
341347
const_cast<ValueType*>(vals.data_handle()), csr_structure);
342348

343-
std::get<0>(stats) = raft::sparse::solver::lanczos_compute_eigenpairs<IndexType, ValueType>(
349+
raft::execute_with_dry_run_check(
344350
handle,
345-
config,
346-
csr_matrix,
347-
std::make_optional(v0.view()),
348-
eigenvalues.view(),
349-
eigenvectors.view());
351+
[&](raft::resources const& h) {
352+
std::get<0>(stats) = raft::sparse::solver::lanczos_compute_eigenpairs<IndexType, ValueType>(
353+
h,
354+
config,
355+
csr_matrix,
356+
std::make_optional(v0.view()),
357+
eigenvalues.view(),
358+
eigenvectors.view());
359+
},
360+
raft::alloc_behavior::ARGUMENT_DRIVEN,
361+
sizeof(ValueType) * n * config.ncv);
350362

351363
ASSERT_TRUE(raft::devArrMatch<ValueType>(eigenvalues.data_handle(),
352364
expected_eigenvalues.data_handle(),

cpp/tests/sparse/spmm.cu

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -189,8 +189,13 @@ class SpmmTest : public ::testing::TestWithParam<SpmmInputs<T>> {
189189
ldz,
190190
params.row_major);
191191

192-
spmm(
193-
handle, params.trans_x, params.trans_y, &alpha, X_csr, y_stride_view, &beta, z_stride_view);
192+
raft::execute_with_dry_run_check(
193+
handle,
194+
[&](raft::resources const& h) {
195+
spmm(h, params.trans_x, params.trans_y, &alpha, X_csr, y_stride_view, &beta, z_stride_view);
196+
},
197+
raft::alloc_behavior::ARGUMENT_DRIVEN,
198+
z_size * sizeof(T));
194199

195200
resource::sync_stream(handle, stream);
196201

cpp/tests/sparse/symmetrize.cu

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2019-2024, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -94,8 +94,14 @@ class SparseSymmetrizeTest
9494

9595
raft::sparse::COO<value_t, value_idx, nnz_t> out(stream);
9696

97-
raft::sparse::linalg::symmetrize(
98-
handle, coo_rows.data(), indices.data(), data.data(), m, n, coo_rows.size(), out);
97+
raft::execute_with_dry_run_check(
98+
handle,
99+
[&](raft::resources const& h) {
100+
raft::sparse::linalg::symmetrize(
101+
h, coo_rows.data(), indices.data(), data.data(), m, n, coo_rows.size(), out);
102+
},
103+
raft::alloc_behavior::DATA_DRIVEN,
104+
nnz * 2 * (2 * sizeof(value_idx) + sizeof(value_t)));
99105

100106
rmm::device_scalar<value_idx> sum(stream);
101107
sum.set_value_to_zero_async(stream);

0 commit comments

Comments
 (0)