Use context for CUDA external memory DMatrix.#12137
Use context for CUDA external memory DMatrix.#12137trivialfis wants to merge 11 commits intodmlc:masterfrom
Conversation
There was a problem hiding this comment.
Pull request overview
This PR threads Context through CUDA external-memory ELLPACK/IterativeDMatrix paths so page creation, cache IO, and raw format read/write can use the caller-provided CUDA context (instead of constructing ad-hoc contexts / default-stream assumptions).
Changes:
- Add
Context const*plumbing toIterativeDMatrix::Load,EllpackPageconstruction,EllpackPageRawFormat, and host-cache write paths. - Introduce
MakePage()hooks in format policies to centralize page allocation (supporting context-aware page types). - Update CUDA tests to use the new context-aware constructors/signatures.
Reviewed changes
Copilot reviewed 16 out of 16 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/cpp/data/test_iterative_dmatrix.cu | Update IterativeDMatrix::Load call to pass Context. |
| tests/cpp/data/test_ellpack_page_raw_format.cu | Update ELLPACK raw-format tests for context-aware SetCuts and EllpackPage construction. |
| src/data/sparse_page_source.h | Add MakePage() to format policy and use it during prefetch page allocation. |
| src/data/iterative_dmatrix.h | Change IterativeDMatrix::Load signature to accept Context const*. |
| src/data/iterative_dmatrix.cu | Pass context into EllpackPage construction and ELLPACK raw format read/write + Load. |
| src/data/iterative_dmatrix.cc | Update non-CUDA stub Load signature to match header. |
| src/data/gradient_index_page_source.h | Add MakePage() for gradient index format policy. |
| src/data/extmem_quantile_dmatrix.cc | Initialize and use fmat_ctx_ consistently for CPU/GPU init and category sync. |
| src/data/ellpack_page.h | Make EllpackPage require an explicit Context const* for empty-page construction. |
| src/data/ellpack_page.cuh | Make EllpackPageImpl require context and store a Context const*. |
| src/data/ellpack_page.cu | Wire context into EllpackPageImpl construction and use context stream in destructor sync. |
| src/data/ellpack_page.cc | Update non-CUDA dummy EllpackPage ctor signature. |
| src/data/ellpack_page_source.h | Store context in EllpackFormatPolicy, pass it into EllpackPageRawFormat, and create pages with it. |
| src/data/ellpack_page_source.cu | Require context for host-cache write and for constructing pages/impls during fetch/caching. |
| src/data/ellpack_page_raw_format.h | Add Context const* member and update ctor signature. |
| src/data/ellpack_page_raw_format.cu | Use provided context/stream for device copies, cache IO, and sync. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
src/data/iterative_dmatrix.cu
Outdated
| if (!ellpack_) { | ||
| ellpack_.reset(new EllpackPage()); | ||
| ellpack_.reset(new EllpackPage{&fmat_ctx_}); | ||
| if (ctx->IsCUDA()) { | ||
| this->Info().feature_types.SetDevice(ctx->Device()); | ||
| *ellpack_->Impl() = |
There was a problem hiding this comment.
EllpackPageImpl now stores a raw Context const* (used by its destructor for stream sync). In this code path, the else branch below constructs a temporary CUDA context via ctx->MakeCUDA() and passes its address into EllpackPageImpl, which will leave impl->ctx_ dangling once the branch exits and can trigger use-after-free later (e.g. in ~EllpackPageImpl()). Ensure the context pointer stored inside EllpackPageImpl always outlives the page (e.g. persist the CUDA context in a longer-lived owner like fmat_ctx_, or avoid storing a raw Context* in EllpackPageImpl and store a device/stream handle instead).
| *vec = common::MakeFixedVecWithCudaMalloc<T>(n); | ||
| dh::safe_cuda( | ||
| cudaMemcpyAsync(vec->data(), ptr, n_bytes, cudaMemcpyDefault, curt::DefaultStream())); | ||
| cudaMemcpyAsync(vec->data(), ptr, n_bytes, cudaMemcpyDefault, ctx->CUDACtx()->Stream())); | ||
| return true; |
There was a problem hiding this comment.
ReadDeviceVec now takes a Context const* and immediately uses ctx->CUDACtx()->Stream() for cudaMemcpyAsync, but it never validates that ctx is non-null / CUDA, nor does it ensure the current CUDA device matches ctx's ordinal. This can lead to invalid stream/device usage when the caller's current device differs from ctx. Add a CHECK(ctx && ctx->IsCUDA()) and set the device (e.g. curt::SetDevice(ctx->Ordinal())) before allocating/copying.
| if (this->param_.prefetch_copy || !has_hmm_ats_) { | ||
| RET_IF_NOT(ReadDeviceVec(fi, &impl->gidx_buffer)); | ||
| RET_IF_NOT(ReadDeviceVec(ctx_, fi, &impl->gidx_buffer)); | ||
| } else { |
There was a problem hiding this comment.
EllpackPageRawFormat::Read uses ctx_ (stored in the format object) for device copies and stream sync, but there’s no precondition check that ctx_ is non-null / CUDA or that the active device matches ctx_. Since the constructor now accepts an arbitrary context pointer, add validation (and potentially a device set) at the start of Read/Write to prevent null deref or cross-device stream usage.
RAMitchell
left a comment
There was a problem hiding this comment.
I find it a bit weird passing a raw pointer all the time without any guardrails for its state, but I guess we strongly guarantee that there is only ever one context and it always exists.
That's indeed a significant issue. I would like to use a shared ptr instead at some point. |
The change by #12121 is too significant, starting with a smaller place.