5 changes: 3 additions & 2 deletions include/xgboost/cache.h
@@ -78,7 +78,8 @@ class DMatrixCache {
auto p_fmat = queue_.front();
auto it = container_.find(p_fmat);
CHECK(it != container_.cend());
if (it->second.ref.expired()) {
// Re-new the cache if this has never been read.
if (it->second.ref.expired() || !it->second.ref.lock()->Info().HasBeenRead()) {
Copilot AI Dec 20, 2025

The cache invalidation logic may be incorrect. The condition checks if the DMatrix has NOT been read (!HasBeenRead()), which seems backwards. According to the comment "Re-new the cache if this has never been read", the logic appears to invalidate caches that haven't been read yet. However, the intent should be to invalidate caches for DMatrix objects that have been modified (where has_been_read_ was reset to false after being true). Consider whether the flag should be renamed to something like "is_modified_" or "needs_cache_refresh_" to better reflect its purpose, or verify that the invalidation condition is correct.

expired.push_back(it->first);
} else {
remained.push(it->first);
@@ -101,7 +102,7 @@ class DMatrixCache {

void ClearExcess() {
this->CheckConsistent();
// clear half of the entries to prevent repeatingly clearing cache.
// clear half of the entries to prevent repeatedly clearing cache.
std::size_t half_size = max_size_ / 2;
while (queue_.size() >= half_size && !queue_.empty()) {
auto p_fmat = queue_.front();
19 changes: 15 additions & 4 deletions include/xgboost/data.h
@@ -150,10 +150,10 @@ class MetaInfo {
* \param fo The output stream.
*/
void SaveBinary(dmlc::Stream* fo) const;
/*!
* \brief Set information in the meta info with array interface.
* \param key The key of the information.
* \param interface_str String representation of json format array interface.
/**
* @brief Set data in the meta info with array interface.
Copilot AI Dec 20, 2025

The documentation change from "information" to "data" may be inaccurate. The SetInfo method accepts various types of metadata (labels, weights, base_margin, etc.) through the array interface, not just "data". The term "information" more accurately describes what this method sets. Consider reverting this documentation change or clarifying that "data" refers to metadata fields.

* @param key The key of the information.
* @param interface_str String representation of json format array interface.
*/
void SetInfo(Context const& ctx, StringView key, StringView interface_str);

@@ -218,6 +218,9 @@ class MetaInfo {
* @brief Setter for categories.
*/
void Cats(std::shared_ptr<CatContainer> cats);
// Flag to indicate whether one needs to refresh the DMatrix cache.
void SetReadFlag(bool has_been_read) { this->has_been_read_ = has_been_read; }
[[nodiscard]] bool HasBeenRead() const { return this->has_been_read_; }

private:
void SetInfoFromHost(Context const* ctx, StringView key, Json arr);
@@ -226,6 +229,7 @@
/*! \brief argsort of labels */
mutable std::vector<size_t> label_order_cache_;
bool has_categorical_{false};
bool has_been_read_{false};
Copilot AI Dec 20, 2025

The has_been_read_ member variable is not thread-safe. The flag is being read and written from multiple locations without synchronization, which could lead to data races in multi-threaded scenarios. Consider either making this field std::atomic or adding proper synchronization mechanisms, especially since the cache operations in cache.h are already protected by a mutex.


std::shared_ptr<CatContainer> cats_;
};
@@ -740,6 +744,7 @@ class DMatrix {

template <>
inline BatchSet<SparsePage> DMatrix::GetBatches() {
this->Info().SetReadFlag(true);
return GetRowBatches();
}

@@ -760,31 +765,37 @@ inline bool DMatrix::PageExists<SparsePage>() const {

template <>
inline BatchSet<SparsePage> DMatrix::GetBatches(Context const*) {
this->Info().SetReadFlag(true);
return GetRowBatches();
}

template <>
inline BatchSet<CSCPage> DMatrix::GetBatches(Context const* ctx) {
this->Info().SetReadFlag(true);
return GetColumnBatches(ctx);
}

template <>
inline BatchSet<SortedCSCPage> DMatrix::GetBatches(Context const* ctx) {
this->Info().SetReadFlag(true);
return GetSortedColumnBatches(ctx);
}

template <>
inline BatchSet<EllpackPage> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
this->Info().SetReadFlag(true);
return GetEllpackBatches(ctx, param);
}

template <>
inline BatchSet<GHistIndexMatrix> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
this->Info().SetReadFlag(true);
return GetGradientIndex(ctx, param);
}

template <>
inline BatchSet<ExtSparsePage> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
this->Info().SetReadFlag(true);
return GetExtBatches(ctx, param);
}
} // namespace xgboost
2 changes: 2 additions & 0 deletions src/data/data.cc
@@ -513,6 +513,8 @@ void CopyTensorInfoImpl(Context const* ctx, Json arr_interface, linalg::Tensor<T
} // namespace

void MetaInfo::SetInfo(Context const& ctx, StringView key, StringView interface_str) {
this->SetReadFlag(false);
Copilot AI Dec 20, 2025

The SetReadFlag(false) call should be placed after the metadata modification is complete, not before. If the modification fails (e.g., Json::Load throws an exception), the flag will be incorrectly set to false even though no actual modification occurred. Consider moving this call to the end of the function, after all validation and modification operations have completed successfully.

Copilot AI Dec 20, 2025

The has_been_read_ flag should be reset when MetaInfo is modified through methods other than SetInfo(). Consider adding SetReadFlag(false) calls to other modification methods such as SetFeatureInfo(), Extend(), Clear(), LoadBinary(), and Cats(std::shared_ptr). Without this, cache invalidation will not work correctly when data is modified through these alternative paths.


Json j_interface = Json::Load(interface_str);
bool is_cuda{false};
if (IsA<Array>(j_interface)) {