-
-
Notifications
You must be signed in to change notification settings - Fork 8.9k
[wip] Invalidate the DMatrix cache after modification. #11885
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -150,10 +150,10 @@ class MetaInfo { | |
| * \param fo The output stream. | ||
| */ | ||
| void SaveBinary(dmlc::Stream* fo) const; | ||
| /*! | ||
| * \brief Set information in the meta info with array interface. | ||
| * \param key The key of the information. | ||
| * \param interface_str String representation of json format array interface. | ||
| /** | ||
| * @brief Set data in the meta info with array interface. | ||
|
||
| * @param key The key of the information. | ||
| * @param interface_str String representation of json format array interface. | ||
| */ | ||
| void SetInfo(Context const& ctx, StringView key, StringView interface_str); | ||
|
|
||
|
|
@@ -218,6 +218,9 @@ class MetaInfo { | |
| * @brief Setter for categories. | ||
| */ | ||
| void Cats(std::shared_ptr<CatContainer> cats); | ||
| // Flag to indicate whether one needs to refresh the DMatrix cache. | ||
| void SetReadFlag(bool has_been_read) { this->has_been_read_ = has_been_read; } | ||
| [[nodiscard]] bool HasBeenRead() const { return this->has_been_read_; } | ||
|
|
||
| private: | ||
| void SetInfoFromHost(Context const* ctx, StringView key, Json arr); | ||
|
|
@@ -226,6 +229,7 @@ class MetaInfo { | |
| /*! \brief argsort of labels */ | ||
| mutable std::vector<size_t> label_order_cache_; | ||
| bool has_categorical_{false}; | ||
| bool has_been_read_{false}; | ||
|
||
|
|
||
| std::shared_ptr<CatContainer> cats_; | ||
| }; | ||
|
|
@@ -740,6 +744,7 @@ class DMatrix { | |
|
|
||
| template <> | ||
| inline BatchSet<SparsePage> DMatrix::GetBatches() { | ||
| this->Info().SetReadFlag(true); | ||
| return GetRowBatches(); | ||
| } | ||
|
|
||
|
|
@@ -760,31 +765,37 @@ inline bool DMatrix::PageExists<SparsePage>() const { | |
|
|
||
| template <> | ||
| inline BatchSet<SparsePage> DMatrix::GetBatches(Context const*) { | ||
| this->Info().SetReadFlag(true); | ||
| return GetRowBatches(); | ||
| } | ||
|
|
||
| template <> | ||
| inline BatchSet<CSCPage> DMatrix::GetBatches(Context const* ctx) { | ||
| this->Info().SetReadFlag(true); | ||
| return GetColumnBatches(ctx); | ||
| } | ||
|
|
||
| template <> | ||
| inline BatchSet<SortedCSCPage> DMatrix::GetBatches(Context const* ctx) { | ||
| this->Info().SetReadFlag(true); | ||
| return GetSortedColumnBatches(ctx); | ||
| } | ||
|
|
||
| template <> | ||
| inline BatchSet<EllpackPage> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) { | ||
| this->Info().SetReadFlag(true); | ||
| return GetEllpackBatches(ctx, param); | ||
| } | ||
|
|
||
| template <> | ||
| inline BatchSet<GHistIndexMatrix> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) { | ||
| this->Info().SetReadFlag(true); | ||
| return GetGradientIndex(ctx, param); | ||
| } | ||
|
|
||
| template <> | ||
| inline BatchSet<ExtSparsePage> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) { | ||
| this->Info().SetReadFlag(true); | ||
| return GetExtBatches(ctx, param); | ||
| } | ||
| } // namespace xgboost | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -513,6 +513,8 @@ void CopyTensorInfoImpl(Context const* ctx, Json arr_interface, linalg::Tensor<T | |
| } // namespace | ||
|
|
||
| void MetaInfo::SetInfo(Context const& ctx, StringView key, StringView interface_str) { | ||
| this->SetReadFlag(false); | ||
|
||
|
|
||
| Json j_interface = Json::Load(interface_str); | ||
| bool is_cuda{false}; | ||
| if (IsA<Array>(j_interface)) { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The cache invalidation logic may be incorrect. The condition checks whether the DMatrix has NOT been read (`!HasBeenRead()`), which seems backwards. According to the comment "Re-new the cache if this has never been read", the logic invalidates caches that have not been read yet. Because `has_been_read_` is initialized to `false`, this condition also holds for a freshly constructed DMatrix that was never modified. However, the intent should be to invalidate caches only for DMatrix objects that have been modified (where `has_been_read_` was reset to `false` after being `true`). Consider renaming the flag to something like `is_modified_` or `needs_cache_refresh_` to better reflect its purpose, or verify that the invalidation condition is correct.