Skip to content

Commit 410cf1d

Browse files
committed
Invalidate the cache.
1 parent 17d957c commit 410cf1d

File tree

3 files changed

+14
-12
lines changed

3 files changed

+14
-12
lines changed

include/xgboost/cache.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,8 @@ class DMatrixCache {
7878
auto p_fmat = queue_.front();
7979
auto it = container_.find(p_fmat);
8080
CHECK(it != container_.cend());
81-
if (it->second.ref.expired()) {
81+
// Re-new the cache if this has never been read.
82+
if (it->second.ref.expired() || !it->second.ref.lock()->Info().HasBeenRead()) {
8283
expired.push_back(it->first);
8384
} else {
8485
remained.push(it->first);

include/xgboost/data.h

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -218,8 +218,9 @@ class MetaInfo {
218218
* @brief Setter for categories.
219219
*/
220220
void Cats(std::shared_ptr<CatContainer> cats);
221-
222-
void SetReadFlag() { this->has_been_read_ = true; }
221+
// Flag to indicate whether one needs to refresh the DMatrix cache.
222+
void SetReadFlag(bool has_been_read) { this->has_been_read_ = has_been_read; }
223+
[[nodiscard]] bool HasBeenRead() const { return this->has_been_read_; }
223224

224225
private:
225226
void SetInfoFromHost(Context const* ctx, StringView key, Json arr);
@@ -743,7 +744,7 @@ class DMatrix {
743744

744745
template <>
745746
inline BatchSet<SparsePage> DMatrix::GetBatches() {
746-
this->Info().SetReadFlag();
747+
this->Info().SetReadFlag(true);
747748
return GetRowBatches();
748749
}
749750

@@ -764,37 +765,37 @@ inline bool DMatrix::PageExists<SparsePage>() const {
764765

765766
template <>
766767
inline BatchSet<SparsePage> DMatrix::GetBatches(Context const*) {
767-
this->Info().SetReadFlag();
768+
this->Info().SetReadFlag(true);
768769
return GetRowBatches();
769770
}
770771

771772
template <>
772773
inline BatchSet<CSCPage> DMatrix::GetBatches(Context const* ctx) {
773-
this->Info().SetReadFlag();
774+
this->Info().SetReadFlag(true);
774775
return GetColumnBatches(ctx);
775776
}
776777

777778
template <>
778779
inline BatchSet<SortedCSCPage> DMatrix::GetBatches(Context const* ctx) {
779-
this->Info().SetReadFlag();
780+
this->Info().SetReadFlag(true);
780781
return GetSortedColumnBatches(ctx);
781782
}
782783

783784
template <>
784785
inline BatchSet<EllpackPage> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
785-
this->Info().SetReadFlag();
786+
this->Info().SetReadFlag(true);
786787
return GetEllpackBatches(ctx, param);
787788
}
788789

789790
template <>
790791
inline BatchSet<GHistIndexMatrix> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
791-
this->Info().SetReadFlag();
792+
this->Info().SetReadFlag(true);
792793
return GetGradientIndex(ctx, param);
793794
}
794795

795796
template <>
796797
inline BatchSet<ExtSparsePage> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
797-
this->Info().SetReadFlag();
798+
this->Info().SetReadFlag(true);
798799
return GetExtBatches(ctx, param);
799800
}
800801
} // namespace xgboost

src/data/data.cc

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -513,8 +513,8 @@ void CopyTensorInfoImpl(Context const* ctx, Json arr_interface, linalg::Tensor<T
513513
} // namespace
514514

515515
void MetaInfo::SetInfo(Context const& ctx, StringView key, StringView interface_str) {
516-
CHECK(!this->has_been_read_)
517-
<< "Don't modify the DMatrix it's used. This violates some caches in XGBoost.";
516+
this->SetReadFlag(false);
517+
518518
Json j_interface = Json::Load(interface_str);
519519
bool is_cuda{false};
520520
if (IsA<Array>(j_interface)) {

0 commit comments

Comments
 (0)