Optional context for the host device vector. #12121
trivialfis wants to merge 21 commits into dmlc:master
Conversation
Pull request overview
This PR introduces an optional Context parameter to HostDeviceVector and linalg::Tensor APIs to enable future stream-aware CUDA behavior (custom streams / device-ordinal handling) and begins updating call sites to pass context through.
Changes:
- Extend the `HostDeviceVector` and `linalg::Tensor` interfaces to accept an optional `Context const*`.
- Update GPU transfer/sync paths in `HostDeviceVector` to use a stream from `CUDAContext` when provided.
- Adjust predictor and CUDA unit tests to pass `Context` into the updated APIs.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| include/xgboost/host_device_vector.h | Adds optional Context parameters to most HostDeviceVector APIs. |
| src/common/host_device_vector.cu | Wires stream selection for copies/fill and updates method signatures to accept context. |
| src/common/host_device_vector.cc | Updates non-CUDA fallback implementation to match new signatures. |
| include/xgboost/linalg.h | Threads optional context through Tensor/views and MakeTensorView helpers. |
| src/predictor/predictor.cc | Passes predictor context into HostDeviceVector operations for predictions init/fill/copy. |
| tests/cpp/common/test_host_device_vector.cu | Updates tests to construct and pass Context through updated APIs. |
| tests/cpp/common/test_linalg.h | Updates linear algebra dispatch helper to pass Context into views/host views. |
| tests/cpp/common/test_linalg.cu | Updates CUDA linalg tests to pass context into Tensor and View calls. |
Comments suppressed due to low confidence (1)
src/common/host_device_vector.cu:43
`HostDeviceVectorImpl(size_t, T, DeviceOrd, CUDAContext const* ctx)` takes `ctx` but doesn't use it. This can trigger `-Wunused-parameter` warnings (often treated as errors) and is confusing given the new stream-aware API. Consider removing the parameter name (or marking it `[[maybe_unused]]` / casting to void), or actually using it where appropriate.
```cpp
HostDeviceVectorImpl(size_t size, T v, DeviceOrd device, CUDAContext const* ctx)
    : device_(device) {
  if (device.IsCUDA()) {
    gpu_access_ = GPUAccess::kWrite;
    SetDevice();
    data_d_->resize(size, v);
```
Pull request overview
Copilot reviewed 33 out of 33 changed files in this pull request and generated 6 comments.
Comments suppressed due to low confidence (2)
src/common/device_compression.cu:372
- In `CompressSnappy`, `DeviceUVector` now requires a stream for mutations. `p_out->resize(...)` (and the earlier `p_out->clear()` in the same function) should take the CUDA context stream to avoid compilation issues and to ensure allocations happen on the intended stream for cudaMallocAsync-backed allocators.
```cuda
SafeNvComp(nvcompBatchedSnappyCompressGetMaxOutputChunkSize(
    std::min(max_in_nbytes, chunk_size), nvcomp_batched_snappy_opts, &max_out_nbytes));
p_out->resize(max_out_nbytes * n_chunks);
std::vector<void*> h_out_ptrs(n_chunks);
```
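A sketch of the change the comment asks for, assuming `DeviceUVector::resize` takes a stream as its second argument and a `ctx` is in scope in `CompressSnappy` (both assumptions come from the review comment, not verified against the tree):

```cuda
// Sketch only: pass the CUDA context stream to mutating DeviceUVector calls,
// so cudaMallocAsync-backed allocations happen on the intended stream.
p_out->clear(ctx->CUDACtx()->Stream());
SafeNvComp(nvcompBatchedSnappyCompressGetMaxOutputChunkSize(
    std::min(max_in_nbytes, chunk_size), nvcomp_batched_snappy_opts, &max_out_nbytes));
p_out->resize(max_out_nbytes * n_chunks, ctx->CUDACtx()->Stream());
```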
src/data/ellpack_page_source.cu:235
`d_page` is allocated using `ctx.CUDACtx()->Stream()`, but the subsequent `cudaMemcpyAsync` to/from that buffer omits the stream parameter (so it defaults to the legacy stream). This can introduce cross-stream races when a custom stream is used; pass `ctx.CUDACtx()->Stream()` to the memcpy (and keep stream usage consistent for the earlier host-cache memcpy as well).
```cuda
auto d_page = common::MakeFixedVecWithCudaMalloc<common::CompressedByteT>(&ctx, remaining);
if (remaining > 0) {
  dh::safe_cuda(cudaMemcpyAsync(d_page.data(), old_impl->gidx_buffer.data() + n_bytes,
                                remaining, cudaMemcpyDefault));
}
```
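The fix suggested by the comment would pass the allocation's stream as `cudaMemcpyAsync`'s fifth argument, keeping the copy ordered after the allocation on the same stream (a sketch, assuming the surrounding names from the excerpt above):

```cuda
// Sketch only: issue the async copy on the same stream the buffer was
// allocated on, avoiding a race against work on the legacy default stream.
auto d_page = common::MakeFixedVecWithCudaMalloc<common::CompressedByteT>(&ctx, remaining);
if (remaining > 0) {
  dh::safe_cuda(cudaMemcpyAsync(d_page.data(), old_impl->gidx_buffer.data() + n_bytes,
                                remaining, cudaMemcpyDefault,
                                ctx.CUDACtx()->Stream()));
}
```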
Force-pushed from a790207 to dbac855.
This is the first PR to add an optional parameter to the `HostDeviceVector` and the `Tensor` class.

ref: #12122