Skip to content

Commit 69543a1

Browse files
committed
Clarify that all workspace resources are actually counted independently despite the hierarchy
1 parent b682d46 commit 69543a1

File tree

5 files changed

+207
-8
lines changed

5 files changed

+207
-8
lines changed

cpp/include/raft/util/dry_run_memory_resource.hpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,20 @@ class dry_run_resources : public resources {
164164

165165
void init()
166166
{
167-
// Force-initialize all affected resources (lazy creation).
167+
// Independent-counting invariant
168+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
169+
// 1. Force-initialize all lazily-created resources (workspace, large workspace,
170+
// pinned, managed) so that their factories resolve against the *original*
171+
// global device MR, not a tracking wrapper we install later.
172+
// 2. Capture every upstream ref while it still points to the original resource.
173+
// 3. Snapshot the resource map to keep the originals alive.
174+
// 4. Only *then* replace the global device resource with the tracking bridge.
175+
// 5. Wrap each captured upstream with a separate dry_run_resource adaptor.
176+
//
177+
// Because step 2 happens before step 4, workspace/lws allocations flow through
178+
// their own adaptor directly to old_device_mr_, bypassing the device bridge.
179+
// Each allocation is therefore counted in exactly one category, and
180+
// memory_stats::total() returns an accurate, non-overlapping sum.
168181
auto* ws = resource::get_workspace_resource(*this);
169182
auto ws_free = resource::get_workspace_free_bytes(*this);
170183
auto ws_upstream = ws->get_upstream_resource();

cpp/include/raft/util/memory_stats_resources.hpp

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,15 @@ struct memory_stats {
4343
std::size_t host_pinned{0};
4444

4545
/**
46-
* @brief Plain sum of all memory stats.
46+
* @brief Sum of all memory stats across the six tracked categories.
4747
*
48-
* Note, this does not take into account the resource hierarchy.
49-
* For example, it's common that workspace resources are allocated from the device global
50-
* resource, so they are effectively counted twice in this function.
48+
* The three resource wrapper classes (dry_run_resources, memory_stats_resources,
49+
* memory_tracking_resources) guarantee that every category is tracked by its own
50+
* independent adaptor: each wrapper force-initializes all resources, captures their
51+
* upstream refs *before* replacing the global device resource, and wraps those
52+
* originals. Workspace and large-workspace allocations therefore bypass the
53+
* device-global tracking adaptor and are counted exactly once, making this sum
54+
* an accurate total when used with stats produced by any of the three wrappers.
5155
*/
5256
[[nodiscard]] inline constexpr auto total() const -> std::size_t
5357
{
@@ -193,6 +197,20 @@ class memory_stats_resources : public resources {
193197

194198
void init()
195199
{
200+
// Independent-counting invariant
201+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
202+
// 1. Force-initialize all lazily-created resources (workspace, large workspace,
203+
// pinned, managed) so that their factories resolve against the *original*
204+
// global device MR, not a tracking wrapper we install later.
205+
// 2. Capture every upstream ref while it still points to the original resource.
206+
// 3. Snapshot the resource map to keep the originals alive.
207+
// 4. Only *then* replace the global device resource with the tracking bridge.
208+
// 5. Wrap each captured upstream with a separate statistics_adaptor.
209+
//
210+
// Because step 2 happens before step 4, workspace/lws allocations flow through
211+
// their own adaptor directly to old_device_mr_, bypassing the device bridge.
212+
// Each allocation is therefore counted in exactly one category, and
213+
// memory_stats::total() returns an accurate, non-overlapping sum.
196214
auto* ws = resource::get_workspace_resource(*this);
197215
auto ws_free = resource::get_workspace_free_bytes(*this);
198216
auto ws_upstream = ws->get_upstream_resource();

cpp/include/raft/util/memory_tracking_resources.hpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,20 @@ class memory_tracking_resources : public resources {
190190

191191
void init()
192192
{
193+
// Independent-counting invariant
194+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
195+
// 1. Force-initialize all lazily-created resources (workspace, large workspace,
196+
// pinned, managed) so that their factories resolve against the *original*
197+
// global device MR, not a tracking wrapper we install later.
198+
// 2. Capture every upstream ref while it still points to the original resource.
199+
// 3. Snapshot the resource map to keep the originals alive.
200+
// 4. Only *then* replace the global device resource with the tracking bridge.
201+
// 5. Wrap each captured upstream with a separate statistics/notifying adaptor.
202+
//
203+
// Because step 2 happens before step 4, workspace/lws allocations flow through
204+
// their own adaptor directly to old_device_mr_, bypassing the device bridge.
205+
// Each allocation is therefore counted in exactly one category, and
206+
// memory_stats::total() returns an accurate, non-overlapping sum.
193207
auto* ws = raft::resource::get_workspace_resource(*this);
194208
auto ws_free = raft::resource::get_workspace_free_bytes(*this);
195209
auto upstream_ref = ws->get_upstream_resource();

cpp/tests/test_utils.cuh

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -362,9 +362,8 @@ void execute_with_dry_run_check(raft::resources const& res,
362362
resource::sync_stream(stat_res);
363363
auto actual = stat_res.get_bytes_peak();
364364

365-
auto total_dry = dry.device_global + dry.device_managed + dry.host + dry.host_pinned;
366-
auto total_actual =
367-
actual.device_global + actual.device_managed + actual.host + actual.host_pinned;
365+
auto total_dry = dry.total();
366+
auto total_actual = actual.total();
368367

369368
if (dry.device_workspace != actual.device_workspace ||
370369
dry.device_large_workspace != actual.device_large_workspace ||

cpp/tests/util/dry_run_memory_resource.cpp

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,4 +232,159 @@ TEST(DryRunExecute, ExceptionSafety)
232232
EXPECT_FALSE(resource::get_dry_run_flag(res));
233233
}
234234

235+
// ===== Independent-counting tests for dry_run_resources =====
236+
237+
TEST(DryRunResources, IndependentCounting_DefaultWorkspace)
238+
{
239+
raft::resources res;
240+
241+
dry_run_resources dry_res(res);
242+
243+
constexpr std::size_t kWsSize = 1024;
244+
constexpr std::size_t kGlobalSize = 2048;
245+
246+
auto* ws_mr = resource::get_workspace_resource(dry_res);
247+
void* ws_ptr = ws_mr->allocate(rmm::cuda_stream_view{}, kWsSize);
248+
249+
auto* dev_mr = rmm::mr::get_current_device_resource();
250+
void* dev_ptr = dev_mr->allocate(rmm::cuda_stream_view{}, kGlobalSize);
251+
252+
auto peak = dry_res.get_bytes_peak();
253+
EXPECT_EQ(peak.device_workspace, kWsSize);
254+
EXPECT_EQ(peak.device_global, kGlobalSize);
255+
EXPECT_EQ(peak.total(), kWsSize + kGlobalSize);
256+
257+
ws_mr->deallocate(rmm::cuda_stream_view{}, ws_ptr, kWsSize);
258+
dev_mr->deallocate(rmm::cuda_stream_view{}, dev_ptr, kGlobalSize);
259+
}
260+
261+
TEST(DryRunResources, IndependentCounting_WorkspaceSetToGlobal)
262+
{
263+
raft::resources res;
264+
resource::set_workspace_to_global_resource(res);
265+
266+
dry_run_resources dry_res(res);
267+
268+
constexpr std::size_t kWsSize = 1024;
269+
constexpr std::size_t kGlobalSize = 2048;
270+
271+
auto* ws_mr = resource::get_workspace_resource(dry_res);
272+
void* ws_ptr = ws_mr->allocate(rmm::cuda_stream_view{}, kWsSize);
273+
274+
auto* dev_mr = rmm::mr::get_current_device_resource();
275+
void* dev_ptr = dev_mr->allocate(rmm::cuda_stream_view{}, kGlobalSize);
276+
277+
auto peak = dry_res.get_bytes_peak();
278+
EXPECT_EQ(peak.device_workspace, kWsSize);
279+
EXPECT_EQ(peak.device_global, kGlobalSize);
280+
EXPECT_EQ(peak.total(), kWsSize + kGlobalSize);
281+
282+
ws_mr->deallocate(rmm::cuda_stream_view{}, ws_ptr, kWsSize);
283+
dev_mr->deallocate(rmm::cuda_stream_view{}, dev_ptr, kGlobalSize);
284+
}
285+
286+
// ===== Independent-counting tests for memory_stats_resources =====
287+
288+
TEST(MemoryStatsResources, IndependentCounting_DefaultWorkspace)
289+
{
290+
raft::resources res;
291+
292+
memory_stats_resources stat_res(res);
293+
294+
constexpr std::size_t kWsSize = 1024;
295+
constexpr std::size_t kGlobalSize = 2048;
296+
297+
auto* ws_mr = resource::get_workspace_resource(stat_res);
298+
void* ws_ptr = ws_mr->allocate(rmm::cuda_stream_view{}, kWsSize);
299+
300+
auto* dev_mr = rmm::mr::get_current_device_resource();
301+
void* dev_ptr = dev_mr->allocate(rmm::cuda_stream_view{}, kGlobalSize);
302+
303+
auto peak = stat_res.get_bytes_peak();
304+
EXPECT_EQ(peak.device_workspace, kWsSize);
305+
EXPECT_EQ(peak.device_global, kGlobalSize);
306+
EXPECT_EQ(peak.total(), kWsSize + kGlobalSize);
307+
308+
ws_mr->deallocate(rmm::cuda_stream_view{}, ws_ptr, kWsSize);
309+
dev_mr->deallocate(rmm::cuda_stream_view{}, dev_ptr, kGlobalSize);
310+
}
311+
312+
TEST(MemoryStatsResources, IndependentCounting_WorkspaceSetToGlobal)
313+
{
314+
raft::resources res;
315+
resource::set_workspace_to_global_resource(res);
316+
317+
memory_stats_resources stat_res(res);
318+
319+
constexpr std::size_t kWsSize = 1024;
320+
constexpr std::size_t kGlobalSize = 2048;
321+
322+
auto* ws_mr = resource::get_workspace_resource(stat_res);
323+
void* ws_ptr = ws_mr->allocate(rmm::cuda_stream_view{}, kWsSize);
324+
325+
auto* dev_mr = rmm::mr::get_current_device_resource();
326+
void* dev_ptr = dev_mr->allocate(rmm::cuda_stream_view{}, kGlobalSize);
327+
328+
auto peak = stat_res.get_bytes_peak();
329+
EXPECT_EQ(peak.device_workspace, kWsSize);
330+
EXPECT_EQ(peak.device_global, kGlobalSize);
331+
EXPECT_EQ(peak.total(), kWsSize + kGlobalSize);
332+
333+
ws_mr->deallocate(rmm::cuda_stream_view{}, ws_ptr, kWsSize);
334+
dev_mr->deallocate(rmm::cuda_stream_view{}, dev_ptr, kGlobalSize);
335+
}
336+
337+
TEST(MemoryStatsResources, IndependentCounting_PoolWorkspace)
338+
{
339+
raft::resources res;
340+
constexpr std::size_t kPoolLimit = 64UL * 1024UL * 1024UL;
341+
resource::set_workspace_to_pool_resource(res, kPoolLimit);
342+
343+
memory_stats_resources stat_res(res);
344+
345+
constexpr std::size_t kWsSize = 1024;
346+
constexpr std::size_t kGlobalSize = 2048;
347+
348+
auto* ws_mr = resource::get_workspace_resource(stat_res);
349+
void* ws_ptr = ws_mr->allocate(rmm::cuda_stream_view{}, kWsSize);
350+
351+
auto* dev_mr = rmm::mr::get_current_device_resource();
352+
void* dev_ptr = dev_mr->allocate(rmm::cuda_stream_view{}, kGlobalSize);
353+
354+
auto peak = stat_res.get_bytes_peak();
355+
EXPECT_EQ(peak.device_workspace, kWsSize);
356+
EXPECT_EQ(peak.device_global, kGlobalSize);
357+
EXPECT_EQ(peak.total(), kWsSize + kGlobalSize);
358+
359+
ws_mr->deallocate(rmm::cuda_stream_view{}, ws_ptr, kWsSize);
360+
dev_mr->deallocate(rmm::cuda_stream_view{}, dev_ptr, kGlobalSize);
361+
}
362+
363+
// ===== Nested wrappers test =====
364+
365+
TEST(IndependentCounting, NestedDryRunInStats)
366+
{
367+
raft::resources res;
368+
369+
memory_stats_resources stat_res(res);
370+
dry_run_resources dry_res(stat_res);
371+
372+
constexpr std::size_t kWsSize = 1024;
373+
constexpr std::size_t kGlobalSize = 2048;
374+
375+
auto* ws_mr = resource::get_workspace_resource(dry_res);
376+
void* ws_ptr = ws_mr->allocate(rmm::cuda_stream_view{}, kWsSize);
377+
378+
auto* dev_mr = rmm::mr::get_current_device_resource();
379+
void* dev_ptr = dev_mr->allocate(rmm::cuda_stream_view{}, kGlobalSize);
380+
381+
auto peak = dry_res.get_bytes_peak();
382+
EXPECT_EQ(peak.device_workspace, kWsSize);
383+
EXPECT_EQ(peak.device_global, kGlobalSize);
384+
EXPECT_EQ(peak.total(), kWsSize + kGlobalSize);
385+
386+
ws_mr->deallocate(rmm::cuda_stream_view{}, ws_ptr, kWsSize);
387+
dev_mr->deallocate(rmm::cuda_stream_view{}, dev_ptr, kGlobalSize);
388+
}
389+
235390
} // namespace raft::util

0 commit comments

Comments
 (0)