Skip to content

Commit c67f9fa

Browse files
committed
Migrate RMM usage to CCCL memory resource design
Remove device_memory_resource base class usage, de-template all resource and adaptor types, replace pointer-based per-device resource APIs with ref-based equivalents, and update all call sites for the new signatures. Part of rapidsai/rmm#2011.
1 parent eb74b0e commit c67f9fa

File tree

19 files changed

+101
-503
lines changed

19 files changed

+101
-503
lines changed

cpp/bench/prims/common/benchmark.hpp

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -17,7 +17,7 @@
1717
#include <rmm/cuda_stream.hpp>
1818
#include <rmm/cuda_stream_view.hpp>
1919
#include <rmm/device_buffer.hpp>
20-
#include <rmm/mr/device_memory_resource.hpp>
20+
#include <rmm/mr/cuda_memory_resource.hpp>
2121
#include <rmm/mr/per_device_resource.hpp>
2222
#include <rmm/mr/pool_memory_resource.hpp>
2323

@@ -33,26 +33,24 @@ namespace raft::bench {
3333
*/
3434
struct using_pool_memory_res {
3535
private:
36-
rmm::mr::device_memory_resource* orig_res_;
3736
rmm::mr::cuda_memory_resource cuda_res_{};
38-
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> pool_res_;
37+
rmm::mr::pool_memory_resource pool_res_;
38+
cuda::mr::any_resource<cuda::mr::device_accessible> prev_res_;
3939

4040
public:
4141
using_pool_memory_res(size_t initial_size, size_t max_size)
42-
: orig_res_(rmm::mr::get_current_device_resource()),
43-
pool_res_(&cuda_res_, initial_size, max_size)
42+
: pool_res_(cuda_res_, initial_size, max_size),
43+
prev_res_(rmm::mr::set_current_device_resource_ref(pool_res_))
4444
{
45-
rmm::mr::set_current_device_resource(&pool_res_);
4645
}
4746

4847
using_pool_memory_res()
49-
: orig_res_(rmm::mr::get_current_device_resource()),
50-
pool_res_(&cuda_res_, rmm::percent_of_free_device_memory(50))
48+
: pool_res_(cuda_res_, rmm::percent_of_free_device_memory(50)),
49+
prev_res_(rmm::mr::set_current_device_resource_ref(pool_res_))
5150
{
52-
rmm::mr::set_current_device_resource(&pool_res_);
5351
}
5452

55-
~using_pool_memory_res() { rmm::mr::set_current_device_resource(orig_res_); }
53+
~using_pool_memory_res() { rmm::mr::set_current_device_resource_ref(prev_res_); }
5654
};
5755

5856
/**

cpp/bench/prims/matrix/gather.cu

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -13,7 +13,7 @@
1313
#include <raft/util/itertools.hpp>
1414

1515
#include <rmm/device_uvector.hpp>
16-
#include <rmm/mr/device_memory_resource.hpp>
16+
#include <rmm/mr/per_device_resource.hpp>
1717
#include <rmm/mr/pool_memory_resource.hpp>
1818

1919
namespace raft::bench::matrix {
@@ -35,18 +35,17 @@ template <typename T, typename MapT, typename IdxT, bool Conditional = false>
3535
struct Gather : public fixture {
3636
Gather(const GatherParams<IdxT>& p)
3737
: params(p),
38-
old_mr(rmm::mr::get_current_device_resource()),
39-
pool_mr(rmm::mr::get_current_device_resource(), 2 * (1ULL << 30)),
38+
pool_mr(rmm::mr::get_current_device_resource_ref(), 2 * (1ULL << 30)),
39+
prev_res_(rmm::mr::set_current_device_resource_ref(pool_mr)),
4040
matrix(this->handle),
4141
map(this->handle),
4242
out(this->handle),
4343
stencil(this->handle),
4444
matrix_h(this->handle)
4545
{
46-
rmm::mr::set_current_device_resource(&pool_mr);
4746
}
4847

49-
~Gather() { rmm::mr::set_current_device_resource(old_mr); }
48+
~Gather() { rmm::mr::set_current_device_resource_ref(prev_res_); }
5049

5150
void allocate_data(const ::benchmark::State& state) override
5251
{
@@ -107,8 +106,8 @@ struct Gather : public fixture {
107106

108107
private:
109108
GatherParams<IdxT> params;
110-
rmm::mr::device_memory_resource* old_mr;
111-
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> pool_mr;
109+
rmm::mr::pool_memory_resource pool_mr;
110+
cuda::mr::any_resource<cuda::mr::device_accessible> prev_res_;
112111
raft::device_matrix<T, IdxT> matrix, out;
113112
raft::host_matrix<T, IdxT> matrix_h;
114113
raft::device_vector<T, IdxT> stencil;

cpp/bench/prims/random/subsample.cu

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,15 @@ template <typename T>
5050
struct sample : public fixture {
5151
sample(const sample_inputs& p)
5252
: params(p),
53-
old_mr(rmm::mr::get_current_device_resource()),
54-
pool_mr(rmm::mr::get_current_device_resource(), 2 * GiB),
53+
pool_mr(rmm::mr::get_current_device_resource_ref(), 2 * GiB),
54+
prev_mr(rmm::mr::set_current_device_resource_ref(pool_mr)),
5555
in(make_device_vector<T, int64_t>(res, p.n_samples)),
5656
out(make_device_vector<T, int64_t>(res, p.n_train))
5757
{
58-
rmm::mr::set_current_device_resource(&pool_mr);
5958
raft::random::RngState r(123456ULL);
6059
}
6160

62-
~sample() { rmm::mr::set_current_device_resource(old_mr); }
61+
~sample() { rmm::mr::set_current_device_resource_ref(prev_mr); }
6362
void run_benchmark(::benchmark::State& state) override
6463
{
6564
std::ostringstream label_stream;
@@ -81,8 +80,8 @@ struct sample : public fixture {
8180
private:
8281
float GiB = 1073741824.0f;
8382
raft::device_resources res;
84-
rmm::mr::device_memory_resource* old_mr;
85-
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> pool_mr;
83+
rmm::mr::pool_memory_resource pool_mr;
84+
cuda::mr::any_resource<cuda::mr::device_accessible> prev_mr;
8685
sample_inputs params;
8786
raft::device_vector<T, int64_t> out, in;
8887
}; // struct sample

cpp/include/raft/core/device_resources.hpp

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -26,7 +26,7 @@
2626

2727
#include <rmm/cuda_stream_pool.hpp>
2828
#include <rmm/exec_policy.hpp>
29-
#include <rmm/mr/device_memory_resource.hpp>
29+
#include <rmm/resource_ref.hpp>
3030

3131
#include <cuda_runtime.h>
3232

@@ -51,15 +51,6 @@ namespace raft {
5151
*/
5252
class device_resources : public resources {
5353
public:
54-
device_resources(const device_resources& handle,
55-
std::shared_ptr<rmm::mr::device_memory_resource> workspace_resource,
56-
std::optional<std::size_t> allocation_limit = std::nullopt)
57-
: resources{handle}
58-
{
59-
// replace the resource factory for the workspace_resources
60-
resource::set_workspace_resource(*this, workspace_resource, allocation_limit);
61-
}
62-
6354
device_resources(const device_resources& handle) : resources{handle} {}
6455
device_resources(device_resources&&) = delete;
6556
device_resources& operator=(device_resources&&) = delete;
@@ -70,25 +61,16 @@ class device_resources : public resources {
7061
* @param[in] stream_view the default stream (which has the default per-thread stream if
7162
* unspecified)
7263
* @param[in] stream_pool the stream pool used (which has default of nullptr if unspecified)
73-
* @param[in] workspace_resource an optional resource used by some functions for allocating
74-
* temporary workspaces.
75-
* @param[in] allocation_limit the total amount of memory in bytes available to the temporary
76-
* workspace resources.
7764
*/
7865
device_resources(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread,
79-
std::shared_ptr<rmm::cuda_stream_pool> stream_pool = {nullptr},
80-
std::shared_ptr<rmm::mr::device_memory_resource> workspace_resource = {nullptr},
81-
std::optional<std::size_t> allocation_limit = std::nullopt)
66+
std::shared_ptr<rmm::cuda_stream_pool> stream_pool = {nullptr})
8267
: resources{}
8368
{
8469
resources::add_resource_factory(std::make_shared<resource::device_id_resource_factory>());
8570
resources::add_resource_factory(
8671
std::make_shared<resource::cuda_stream_resource_factory>(stream_view));
8772
resources::add_resource_factory(
8873
std::make_shared<resource::cuda_stream_pool_resource_factory>(stream_pool));
89-
if (workspace_resource) {
90-
resource::set_workspace_resource(*this, workspace_resource, allocation_limit);
91-
}
9274
}
9375

9476
/** Destroys all held-up resources */
@@ -214,7 +196,7 @@ class device_resources : public resources {
214196
return resource::get_subcomm(*this, key);
215197
}
216198

217-
rmm::mr::device_memory_resource* get_workspace_resource() const
199+
rmm::mr::limiting_resource_adaptor* get_workspace_resource() const
218200
{
219201
return resource::get_workspace_resource(*this);
220202
}

cpp/include/raft/core/device_resources_manager.hpp

Lines changed: 13 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -12,6 +12,7 @@
1212
#include <rmm/cuda_stream_pool.hpp>
1313
#include <rmm/mr/cuda_memory_resource.hpp>
1414
#include <rmm/mr/per_device_resource.hpp>
15+
#include <rmm/mr/pool_memory_resource.hpp>
1516

1617
#include <algorithm>
1718
#include <memory>
@@ -114,12 +115,6 @@ struct device_resources_manager {
114115
std::optional<std::size_t> max_mem_pool_size{std::size_t{}};
115116
// Limit on workspace memory for the returned device_resources object
116117
std::optional<std::size_t> workspace_allocation_limit{std::nullopt};
117-
// Optional specification of separate workspace memory resources for each
118-
// device. The integer in each pair indicates the device for this memory
119-
// resource.
120-
std::vector<std::pair<std::shared_ptr<rmm::mr::device_memory_resource>, int>> workspace_mrs{};
121-
122-
auto get_workspace_memory_resource(int device_id) {}
123118
} params_;
124119

125120
// This struct stores the underlying resources to be shared among
@@ -152,35 +147,18 @@ struct device_resources_manager {
152147
}()},
153148
pool_mr_{[&params, this]() {
154149
auto scoped_device = device_setter{device_id_};
155-
auto result =
156-
std::shared_ptr<rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource>>{nullptr};
150+
auto result = std::optional<rmm::mr::pool_memory_resource>{};
157151
// If max_mem_pool_size is nullopt or non-zero, create a pool memory
158152
// resource
159153
if (params.max_mem_pool_size.value_or(1) != 0) {
160-
auto* upstream =
161-
dynamic_cast<rmm::mr::cuda_memory_resource*>(rmm::mr::get_current_device_resource());
162-
if (upstream != nullptr) {
163-
result =
164-
std::make_shared<rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource>>(
165-
upstream,
166-
params.init_mem_pool_size.value_or(rmm::percent_of_free_device_memory(50)),
167-
params.max_mem_pool_size);
168-
rmm::mr::set_current_device_resource(result.get());
169-
} else {
170-
RAFT_LOG_WARN(
171-
"Pool allocation requested, but other memory resource has already been set and "
172-
"will not be overwritten");
173-
}
154+
auto upstream = rmm::mr::get_current_device_resource_ref();
155+
result.emplace(
156+
upstream,
157+
params.init_mem_pool_size.value_or(rmm::percent_of_free_device_memory(50)),
158+
params.max_mem_pool_size);
159+
rmm::mr::set_current_device_resource_ref(*result);
174160
}
175161
return result;
176-
}()},
177-
workspace_mr_{[&params, this]() {
178-
auto result = std::shared_ptr<rmm::mr::device_memory_resource>{nullptr};
179-
auto iter = std::find_if(std::begin(params.workspace_mrs),
180-
std::end(params.workspace_mrs),
181-
[this](auto&& pair) { return pair.second == device_id_; });
182-
if (iter != std::end(params.workspace_mrs)) { result = iter->first; }
183-
return result;
184162
}()}
185163
{
186164
}
@@ -216,27 +194,14 @@ struct device_resources_manager {
216194
if (pool_count() != 0) { result = pools_[get_thread_id() % pool_count()]; }
217195
return result;
218196
}
219-
// Return a (possibly null) shared_ptr to the pool memory resource
220-
// created for this device by the manager
221-
[[nodiscard]] auto get_pool_memory_resource() const { return pool_mr_; }
222-
// Return the RAFT workspace allocation limit that will be used by
223-
// `device_resources` returned from this manager
224-
[[nodiscard]] auto get_workspace_allocation_limit() const
225-
{
226-
return workspace_allocation_limit_;
227-
}
228-
// Return a (possibly null) shared_ptr to the memory resource that will
229-
// be used for workspace allocations by `device_resources` returned from
230-
// this manager
231-
[[nodiscard]] auto get_workspace_memory_resource() { return workspace_mr_; }
197+
// Return the pool memory resource created for this device by the manager (if any)
198+
[[nodiscard]] auto& get_pool_memory_resource() { return pool_mr_; }
232199

233200
private:
234201
int device_id_;
235202
std::unique_ptr<rmm::cuda_stream_pool> streams_;
236203
std::vector<std::shared_ptr<rmm::cuda_stream_pool>> pools_;
237-
std::shared_ptr<rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource>> pool_mr_;
238-
std::shared_ptr<rmm::mr::device_memory_resource> workspace_mr_;
239-
std::optional<std::size_t> workspace_allocation_limit_{std::nullopt};
204+
std::optional<rmm::mr::pool_memory_resource> pool_mr_;
240205
};
241206

242207
// Mutex used to lock access to shared data until after the first
@@ -290,10 +255,7 @@ struct device_resources_manager {
290255
auto scoped_device = device_setter(device_id);
291256
// Build the device_resources object for this thread out of shared
292257
// components
293-
thread_resources[device_id].emplace(component_iter->get_stream(),
294-
component_iter->get_pool(),
295-
component_iter->get_workspace_memory_resource(),
296-
component_iter->get_workspace_allocation_limit());
258+
thread_resources[device_id].emplace(component_iter->get_stream(), component_iter->get_pool());
297259
}
298260

299261
return thread_resources[device_id].value();
@@ -373,27 +335,6 @@ struct device_resources_manager {
373335
}
374336
}
375337

376-
// Thread-safe setter for workspace memory resources
377-
void set_workspace_memory_resource_(std::shared_ptr<rmm::mr::device_memory_resource> mr,
378-
int device_id)
379-
{
380-
auto lock = get_lock();
381-
if (params_finalized_) {
382-
RAFT_LOG_WARN(
383-
"Attempted to set device_resources_manager properties after resources have already been "
384-
"retrieved");
385-
} else {
386-
auto iter = std::find_if(std::begin(params_.workspace_mrs),
387-
std::end(params_.workspace_mrs),
388-
[device_id](auto&& pair) { return pair.second == device_id; });
389-
if (iter != std::end(params_.workspace_mrs)) {
390-
iter->first = mr;
391-
} else {
392-
params_.workspace_mrs.emplace_back(mr, device_id);
393-
}
394-
}
395-
}
396-
397338
// Retrieve the instance of this singleton
398339
static auto& get_manager()
399340
{
@@ -543,24 +484,5 @@ struct device_resources_manager {
543484
set_init_mem_pool_size(init_mem);
544485
set_max_mem_pool_size(max_mem);
545486
}
546-
547-
/**
548-
* @brief Set the workspace memory resource to be used on a specific device
549-
*
550-
* RAFT device_resources objects can be built with a separate memory
551-
* resource for allocating temporary workspaces. If a (non-nullptr) memory
552-
* resource is provided by this setter, it will be used as the
553-
* workspace memory resource for all `device_resources` returned for the
554-
* indicated device.
555-
*
556-
* If called after the first call to
557-
* `raft::device_resources_manager::get_device_resources`, no change will be made,
558-
* and a warning will be emitted.
559-
*/
560-
static void set_workspace_memory_resource(std::shared_ptr<rmm::mr::device_memory_resource> mr,
561-
int device_id = device_setter::get_current_device())
562-
{
563-
get_manager().set_workspace_memory_resource_(mr, device_id);
564-
}
565487
};
566488
} // namespace raft

cpp/include/raft/core/device_resources_snmg.hpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#include <raft/core/resource/resource_types.hpp>
1111

1212
#include <rmm/cuda_device.hpp>
13-
#include <rmm/mr/device_memory_resource.hpp>
1413
#include <rmm/mr/per_device_resource.hpp>
1514
#include <rmm/mr/pool_memory_resource.hpp>
1615

@@ -105,10 +104,9 @@ class device_resources_snmg : public device_resources {
105104
int device_id = raft::resource::get_device_id(dev_res);
106105
pool_device_ids_.push_back(device_id);
107106

108-
per_device_pools_.push_back(
109-
std::make_unique<rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>>(
110-
rmm::mr::get_current_device_resource_ref(),
111-
rmm::percent_of_free_device_memory(percent_of_free_memory)));
107+
per_device_pools_.push_back(std::make_unique<rmm::mr::pool_memory_resource>(
108+
rmm::mr::get_current_device_resource_ref(),
109+
rmm::percent_of_free_device_memory(percent_of_free_memory)));
112110
rmm::mr::set_per_device_resource_ref(rmm::cuda_device_id{device_id},
113111
*per_device_pools_.back());
114112
}
@@ -151,8 +149,7 @@ class device_resources_snmg : public device_resources {
151149
}
152150
}
153151
int main_gpu_id_;
154-
std::vector<std::unique_ptr<rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>>>
155-
per_device_pools_;
152+
std::vector<std::unique_ptr<rmm::mr::pool_memory_resource>> per_device_pools_;
156153
std::vector<int> pool_device_ids_;
157154
}; // class device_resources_snmg
158155

0 commit comments

Comments
 (0)