Skip to content

Use context for CUDA external memory DMatrix.#12137

Open
trivialfis wants to merge 11 commits intodmlc:masterfrom
trivialfis:ext-ctx
Open

Use context for CUDA external memory DMatrix.#12137
trivialfis wants to merge 11 commits intodmlc:masterfrom
trivialfis:ext-ctx

Conversation

@trivialfis
Copy link
Copy Markdown
Member

@trivialfis trivialfis commented Apr 1, 2026

The change by #12121 is too significant, starting with a smaller place.

  • Use a sync policy to let the caller handle stream sync. This avoids flaky lifetime management in the dtor.

@trivialfis trivialfis requested review from RAMitchell and Copilot and removed request for Copilot April 1, 2026 18:26
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR threads Context through CUDA external-memory ELLPACK/IterativeDMatrix paths so page creation, cache IO, and raw format read/write can use the caller-provided CUDA context (instead of constructing ad-hoc contexts / default-stream assumptions).

Changes:

  • Add Context const* plumbing to IterativeDMatrix::Load, EllpackPage construction, EllpackPageRawFormat, and host-cache write paths.
  • Introduce MakePage() hooks in format policies to centralize page allocation (supporting context-aware page types).
  • Update CUDA tests to use the new context-aware constructors/signatures.

Reviewed changes

Copilot reviewed 16 out of 16 changed files in this pull request and generated 5 comments.

Show a summary per file
File Description
tests/cpp/data/test_iterative_dmatrix.cu Update IterativeDMatrix::Load call to pass Context.
tests/cpp/data/test_ellpack_page_raw_format.cu Update ELLPACK raw-format tests for context-aware SetCuts and EllpackPage construction.
src/data/sparse_page_source.h Add MakePage() to format policy and use it during prefetch page allocation.
src/data/iterative_dmatrix.h Change IterativeDMatrix::Load signature to accept Context const*.
src/data/iterative_dmatrix.cu Pass context into EllpackPage construction and ELLPACK raw format read/write + Load.
src/data/iterative_dmatrix.cc Update non-CUDA stub Load signature to match header.
src/data/gradient_index_page_source.h Add MakePage() for gradient index format policy.
src/data/extmem_quantile_dmatrix.cc Initialize and use fmat_ctx_ consistently for CPU/GPU init and category sync.
src/data/ellpack_page.h Make EllpackPage require an explicit Context const* for empty-page construction.
src/data/ellpack_page.cuh Make EllpackPageImpl require context and store a Context const*.
src/data/ellpack_page.cu Wire context into EllpackPageImpl construction and use context stream in destructor sync.
src/data/ellpack_page.cc Update non-CUDA dummy EllpackPage ctor signature.
src/data/ellpack_page_source.h Store context in EllpackFormatPolicy, pass it into EllpackPageRawFormat, and create pages with it.
src/data/ellpack_page_source.cu Require context for host-cache write and for constructing pages/impls during fetch/caching.
src/data/ellpack_page_raw_format.h Add Context const* member and update ctor signature.
src/data/ellpack_page_raw_format.cu Use provided context/stream for device copies, cache IO, and sync.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines 109 to 113
if (!ellpack_) {
ellpack_.reset(new EllpackPage());
ellpack_.reset(new EllpackPage{&fmat_ctx_});
if (ctx->IsCUDA()) {
this->Info().feature_types.SetDevice(ctx->Device());
*ellpack_->Impl() =
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EllpackPageImpl now stores a raw Context const* (used by its destructor for stream sync). In this code path, the else branch below constructs a temporary CUDA context via ctx->MakeCUDA() and passes its address into EllpackPageImpl, which will leave impl->ctx_ dangling once the branch exits and can trigger use-after-free later (e.g. in ~EllpackPageImpl()). Ensure the context pointer stored inside EllpackPageImpl always outlives the page (e.g. persist the CUDA context in a longer-lived owner like fmat_ctx_, or avoid storing a raw Context* in EllpackPageImpl and store a device/stream handle instead).

Copilot uses AI. Check for mistakes.
Comment on lines 43 to 46
*vec = common::MakeFixedVecWithCudaMalloc<T>(n);
dh::safe_cuda(
cudaMemcpyAsync(vec->data(), ptr, n_bytes, cudaMemcpyDefault, curt::DefaultStream()));
cudaMemcpyAsync(vec->data(), ptr, n_bytes, cudaMemcpyDefault, ctx->CUDACtx()->Stream()));
return true;
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ReadDeviceVec now takes a Context const* and immediately uses ctx->CUDACtx()->Stream() for cudaMemcpyAsync, but it never validates that ctx is non-null / CUDA, nor does it ensure the current CUDA device matches ctx's ordinal. This can lead to invalid stream/device usage when the caller's current device differs from ctx. Add a CHECK(ctx && ctx->IsCUDA()) and set the device (e.g. curt::SetDevice(ctx->Ordinal())) before allocating/copying.

Copilot uses AI. Check for mistakes.
Comment on lines 64 to 66
if (this->param_.prefetch_copy || !has_hmm_ats_) {
RET_IF_NOT(ReadDeviceVec(fi, &impl->gidx_buffer));
RET_IF_NOT(ReadDeviceVec(ctx_, fi, &impl->gidx_buffer));
} else {
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EllpackPageRawFormat::Read uses ctx_ (stored in the format object) for device copies and stream sync, but there’s no precondition check that ctx_ is non-null / CUDA or that the active device matches ctx_. Since the constructor now accepts an arbitrary context pointer, add validation (and potentially a device set) at the start of Read/Write to prevent null deref or cross-device stream usage.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Member

@RAMitchell RAMitchell left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find it a bit weird passing a raw pointer all the time without any guardrails for its state, but I guess we strongly guarantee that there is only ever one context and it always exists.

@trivialfis
Copy link
Copy Markdown
Member Author

I find it a bit weird passing a raw pointer all the time without any guardrails for its state

That's indeed a significant issue. I would like to use a shared ptr instead at some point.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants