|
1 | 1 | /* |
2 | | - * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. |
| 2 | + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. |
3 | 3 | * SPDX-License-Identifier: Apache-2.0 |
4 | 4 | */ |
5 | 5 |
|
|
12 | 12 | #include <rmm/cuda_stream_pool.hpp> |
13 | 13 | #include <rmm/mr/cuda_memory_resource.hpp> |
14 | 14 | #include <rmm/mr/per_device_resource.hpp> |
| 15 | +#include <rmm/mr/pool_memory_resource.hpp> |
15 | 16 |
|
16 | 17 | #include <algorithm> |
17 | 18 | #include <memory> |
@@ -114,12 +115,6 @@ struct device_resources_manager { |
114 | 115 | std::optional<std::size_t> max_mem_pool_size{std::size_t{}}; |
115 | 116 | // Limit on workspace memory for the returned device_resources object |
116 | 117 | std::optional<std::size_t> workspace_allocation_limit{std::nullopt}; |
117 | | - // Optional specification of separate workspace memory resources for each |
118 | | - // device. The integer in each pair indicates the device for this memory |
119 | | - // resource. |
120 | | - std::vector<std::pair<std::shared_ptr<rmm::mr::device_memory_resource>, int>> workspace_mrs{}; |
121 | | - |
122 | | - auto get_workspace_memory_resource(int device_id) {} |
123 | 118 | } params_; |
124 | 119 |
|
125 | 120 | // This struct stores the underlying resources to be shared among |
@@ -152,35 +147,18 @@ struct device_resources_manager { |
152 | 147 | }()}, |
153 | 148 | pool_mr_{[¶ms, this]() { |
154 | 149 | auto scoped_device = device_setter{device_id_}; |
155 | | - auto result = |
156 | | - std::shared_ptr<rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource>>{nullptr}; |
| 150 | + auto result = std::optional<rmm::mr::pool_memory_resource>{}; |
157 | 151 | // If max_mem_pool_size is nullopt or non-zero, create a pool memory |
158 | 152 | // resource |
159 | 153 | if (params.max_mem_pool_size.value_or(1) != 0) { |
160 | | - auto* upstream = |
161 | | - dynamic_cast<rmm::mr::cuda_memory_resource*>(rmm::mr::get_current_device_resource()); |
162 | | - if (upstream != nullptr) { |
163 | | - result = |
164 | | - std::make_shared<rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource>>( |
165 | | - upstream, |
166 | | - params.init_mem_pool_size.value_or(rmm::percent_of_free_device_memory(50)), |
167 | | - params.max_mem_pool_size); |
168 | | - rmm::mr::set_current_device_resource(result.get()); |
169 | | - } else { |
170 | | - RAFT_LOG_WARN( |
171 | | - "Pool allocation requested, but other memory resource has already been set and " |
172 | | - "will not be overwritten"); |
173 | | - } |
| 154 | + auto upstream = rmm::mr::get_current_device_resource_ref(); |
| 155 | + result.emplace( |
| 156 | + upstream, |
| 157 | + params.init_mem_pool_size.value_or(rmm::percent_of_free_device_memory(50)), |
| 158 | + params.max_mem_pool_size); |
| 159 | + rmm::mr::set_current_device_resource_ref(*result); |
174 | 160 | } |
175 | 161 | return result; |
176 | | - }()}, |
177 | | - workspace_mr_{[¶ms, this]() { |
178 | | - auto result = std::shared_ptr<rmm::mr::device_memory_resource>{nullptr}; |
179 | | - auto iter = std::find_if(std::begin(params.workspace_mrs), |
180 | | - std::end(params.workspace_mrs), |
181 | | - [this](auto&& pair) { return pair.second == device_id_; }); |
182 | | - if (iter != std::end(params.workspace_mrs)) { result = iter->first; } |
183 | | - return result; |
184 | 162 | }()} |
185 | 163 | { |
186 | 164 | } |
@@ -216,27 +194,14 @@ struct device_resources_manager { |
216 | 194 | if (pool_count() != 0) { result = pools_[get_thread_id() % pool_count()]; } |
217 | 195 | return result; |
218 | 196 | } |
219 | | - // Return a (possibly null) shared_ptr to the pool memory resource |
220 | | - // created for this device by the manager |
221 | | - [[nodiscard]] auto get_pool_memory_resource() const { return pool_mr_; } |
222 | | - // Return the RAFT workspace allocation limit that will be used by |
223 | | - // `device_resources` returned from this manager |
224 | | - [[nodiscard]] auto get_workspace_allocation_limit() const |
225 | | - { |
226 | | - return workspace_allocation_limit_; |
227 | | - } |
228 | | - // Return a (possibly null) shared_ptr to the memory resource that will |
229 | | - // be used for workspace allocations by `device_resources` returned from |
230 | | - // this manager |
231 | | - [[nodiscard]] auto get_workspace_memory_resource() { return workspace_mr_; } |
| 197 | + // Return the pool memory resource created for this device by the manager (if any) |
| 198 | + [[nodiscard]] auto& get_pool_memory_resource() { return pool_mr_; } |
232 | 199 |
|
233 | 200 | private: |
234 | 201 | int device_id_; |
235 | 202 | std::unique_ptr<rmm::cuda_stream_pool> streams_; |
236 | 203 | std::vector<std::shared_ptr<rmm::cuda_stream_pool>> pools_; |
237 | | - std::shared_ptr<rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource>> pool_mr_; |
238 | | - std::shared_ptr<rmm::mr::device_memory_resource> workspace_mr_; |
239 | | - std::optional<std::size_t> workspace_allocation_limit_{std::nullopt}; |
| 204 | + std::optional<rmm::mr::pool_memory_resource> pool_mr_; |
240 | 205 | }; |
241 | 206 |
|
242 | 207 | // Mutex used to lock access to shared data until after the first |
@@ -290,10 +255,7 @@ struct device_resources_manager { |
290 | 255 | auto scoped_device = device_setter(device_id); |
291 | 256 | // Build the device_resources object for this thread out of shared |
292 | 257 | // components |
293 | | - thread_resources[device_id].emplace(component_iter->get_stream(), |
294 | | - component_iter->get_pool(), |
295 | | - component_iter->get_workspace_memory_resource(), |
296 | | - component_iter->get_workspace_allocation_limit()); |
| 258 | + thread_resources[device_id].emplace(component_iter->get_stream(), component_iter->get_pool()); |
297 | 259 | } |
298 | 260 |
|
299 | 261 | return thread_resources[device_id].value(); |
@@ -373,27 +335,6 @@ struct device_resources_manager { |
373 | 335 | } |
374 | 336 | } |
375 | 337 |
|
376 | | - // Thread-safe setter for workspace memory resources |
377 | | - void set_workspace_memory_resource_(std::shared_ptr<rmm::mr::device_memory_resource> mr, |
378 | | - int device_id) |
379 | | - { |
380 | | - auto lock = get_lock(); |
381 | | - if (params_finalized_) { |
382 | | - RAFT_LOG_WARN( |
383 | | - "Attempted to set device_resources_manager properties after resources have already been " |
384 | | - "retrieved"); |
385 | | - } else { |
386 | | - auto iter = std::find_if(std::begin(params_.workspace_mrs), |
387 | | - std::end(params_.workspace_mrs), |
388 | | - [device_id](auto&& pair) { return pair.second == device_id; }); |
389 | | - if (iter != std::end(params_.workspace_mrs)) { |
390 | | - iter->first = mr; |
391 | | - } else { |
392 | | - params_.workspace_mrs.emplace_back(mr, device_id); |
393 | | - } |
394 | | - } |
395 | | - } |
396 | | - |
397 | 338 | // Retrieve the instance of this singleton |
398 | 339 | static auto& get_manager() |
399 | 340 | { |
@@ -543,24 +484,5 @@ struct device_resources_manager { |
543 | 484 | set_init_mem_pool_size(init_mem); |
544 | 485 | set_max_mem_pool_size(max_mem); |
545 | 486 | } |
546 | | - |
547 | | - /** |
548 | | - * @brief Set the workspace memory resource to be used on a specific device |
549 | | - * |
550 | | - * RAFT device_resources objects can be built with a separate memory |
551 | | - * resource for allocating temporary workspaces. If a (non-nullptr) memory |
552 | | - * resource is provided by this setter, it will be used as the |
553 | | - * workspace memory resource for all `device_resources` returned for the |
554 | | - * indicated device. |
555 | | - * |
556 | | - * If called after the first call to |
557 | | - * `raft::device_resources_manager::get_device_resources`, no change will be made, |
558 | | - * and a warning will be emitted. |
559 | | - */ |
560 | | - static void set_workspace_memory_resource(std::shared_ptr<rmm::mr::device_memory_resource> mr, |
561 | | - int device_id = device_setter::get_current_device()) |
562 | | - { |
563 | | - get_manager().set_workspace_memory_resource_(mr, device_id); |
564 | | - } |
565 | 487 | }; |
566 | 488 | } // namespace raft |
0 commit comments