Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions include/xgboost/objective.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,14 @@ class ObjFunction : public Configurable {
* @param name Name of the objective.
* @param ctx Pointer to the context.
*/
static ObjFunction* Create(const std::string& name, Context const* ctx);
static ObjFunction* Create(const std::string& name, Context const* ctx, Args const& args = {});
};

/*!
* \brief Registry entry for objective factory functions.
*/
struct ObjFunctionReg
: public dmlc::FunctionRegEntryBase<ObjFunctionReg,
std::function<ObjFunction* ()> > {
: public dmlc::FunctionRegEntryBase<ObjFunctionReg, std::function<ObjFunction*(Args const&)> > {
};

/*!
Expand All @@ -154,14 +153,14 @@ struct ObjFunctionReg
* // example of registering a objective
* XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:squarederror")
* .describe("Linear regression objective")
* .set_body([]() {
* .set_body([](Args const&) {
* return new RegLossObj(LossType::kLinearSquare);
* });
* \endcode
*/
#define XGBOOST_REGISTER_OBJECTIVE(UniqueId, Name) \
static DMLC_ATTRIBUTE_UNUSED ::xgboost::ObjFunctionReg & \
__make_ ## ObjFunctionReg ## _ ## UniqueId ## __ = \
::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->__REGISTER__(Name)
#define XGBOOST_REGISTER_OBJECTIVE(UniqueId, Name) \
static DMLC_ATTRIBUTE_UNUSED ::xgboost::ObjFunctionReg& \
__make_##ObjFunctionReg##_##UniqueId##__ = \
::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->__REGISTER__(Name)
} // namespace xgboost
#endif // XGBOOST_OBJECTIVE_H_
20 changes: 9 additions & 11 deletions plugin/example/custom_obj.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ struct MyLogisticParam : public XGBoostParameter<MyLogisticParam> {
float scale_neg_weight;
// declare parameters
DMLC_DECLARE_PARAMETER(MyLogisticParam) {
DMLC_DECLARE_FIELD(scale_neg_weight).set_default(1.0f).set_lower_bound(0.0f)
DMLC_DECLARE_FIELD(scale_neg_weight)
.set_default(1.0f)
.set_lower_bound(0.0f)
.describe("Scale the weight of negative examples by this factor");
}
};
Expand Down Expand Up @@ -53,12 +55,10 @@ class MyLogistic : public ObjFunction {
out_gpair_h(i) = GradientPair(grad, hess);
}
}
[[nodiscard]] const char* DefaultEvalMetric() const override {
return "logloss";
}
void PredTransform(HostDeviceVector<float> *io_preds) const override {
[[nodiscard]] const char* DefaultEvalMetric() const override { return "logloss"; }
void PredTransform(HostDeviceVector<float>* io_preds) const override {
// transform margin value to probability.
std::vector<float> &preds = io_preds->HostVector();
std::vector<float>& preds = io_preds->HostVector();
for (auto& pred : preds) {
pred = 1.0f / (1.0f + std::exp(-pred));
}
Expand All @@ -77,9 +77,7 @@ class MyLogistic : public ObjFunction {
out["my_logistic_param"] = ToJson(param_);
}

void LoadConfig(Json const& in) override {
FromJson(in["my_logistic_param"], &param_);
}
void LoadConfig(Json const& in) override { FromJson(in["my_logistic_param"], &param_); }

private:
MyLogisticParam param_;
Expand All @@ -88,7 +86,7 @@ class MyLogistic : public ObjFunction {
// Finally register the objective function.
// After it succeeds you can try use xgboost with objective=mylogistic
XGBOOST_REGISTER_OBJECTIVE(MyLogistic, "mylogistic")
.describe("User defined logistic regression plugin")
.set_body([]() { return new MyLogistic(); });
.describe("User defined logistic regression plugin")
.set_body([](Args const&) { return new MyLogistic(); });

} // namespace xgboost::obj
10 changes: 5 additions & 5 deletions src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -847,17 +847,17 @@ class LearnerConfiguration : public Intercept {
// Rename one of them once binary IO is gone.
cfg_["max_delta_step"] = kMaxDeltaStepDefaultValue;
}
if (obj_ == nullptr || tparam_.objective != old.objective) {
obj_.reset(ObjFunction::Create(tparam_.objective, &ctx_));
}

bool has_nc{cfg_.find("num_class") != cfg_.cend()};
// Inject num_class into configuration.
// FIXME(jiamingy): Remove the duplicated parameter in softmax
cfg_["num_class"] = std::to_string(mparam_.num_class);
auto& args = *p_args;
args = {cfg_.cbegin(), cfg_.cend()}; // renew
obj_->Configure(args);
if (obj_ == nullptr || tparam_.objective != old.objective) {
obj_.reset(ObjFunction::Create(tparam_.objective, &ctx_, args));
} else {
obj_->Configure(args);
}
if (!has_nc) {
cfg_.erase("num_class");
}
Expand Down
86 changes: 40 additions & 46 deletions src/objective/aft_obj.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ DMLC_REGISTRY_FILE_TAG(aft_obj_gpu);

class AFTObj : public ObjFunction {
public:
void Configure(Args const& args) override {
param_.UpdateAllowUnknown(args);
}
explicit AFTObj(Args const& args) { param_.UpdateAllowUnknown(args); }
AFTObj() = default;

void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); }

ObjInfo Task() const override { return ObjInfo::kSurvival; }

Expand All @@ -42,27 +43,24 @@ class AFTObj : public ObjFunction {
linalg::Matrix<GradientPair>* out_gpair, size_t ndata, DeviceOrd device,
bool is_null_weight, float aft_loss_distribution_scale) {
common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx,
common::Span<GradientPair> _out_gpair,
common::Span<const bst_float> _preds,
common::Span<const bst_float> _labels_lower_bound,
common::Span<const bst_float> _labels_upper_bound,
common::Span<const bst_float> _weights) {
const double pred = static_cast<double>(_preds[_idx]);
const double label_lower_bound = static_cast<double>(_labels_lower_bound[_idx]);
const double label_upper_bound = static_cast<double>(_labels_upper_bound[_idx]);
const float grad = static_cast<float>(
AFTLoss<Distribution>::Gradient(label_lower_bound, label_upper_bound,
pred, aft_loss_distribution_scale));
const float hess = static_cast<float>(
AFTLoss<Distribution>::Hessian(label_lower_bound, label_upper_bound,
pred, aft_loss_distribution_scale));
const bst_float w = is_null_weight ? 1.0f : _weights[_idx];
_out_gpair[_idx] = GradientPair(grad * w, hess * w);
},
common::Range{0, static_cast<int64_t>(ndata)}, this->ctx_->Threads(), device).Eval(
out_gpair->Data(), &preds, &info.labels_lower_bound_, &info.labels_upper_bound_,
&info.weights_);
[=] XGBOOST_DEVICE(size_t _idx, common::Span<GradientPair> _out_gpair,
common::Span<const bst_float> _preds,
common::Span<const bst_float> _labels_lower_bound,
common::Span<const bst_float> _labels_upper_bound,
common::Span<const bst_float> _weights) {
const double pred = static_cast<double>(_preds[_idx]);
const double label_lower_bound = static_cast<double>(_labels_lower_bound[_idx]);
const double label_upper_bound = static_cast<double>(_labels_upper_bound[_idx]);
const float grad = static_cast<float>(AFTLoss<Distribution>::Gradient(
label_lower_bound, label_upper_bound, pred, aft_loss_distribution_scale));
const float hess = static_cast<float>(AFTLoss<Distribution>::Hessian(
label_lower_bound, label_upper_bound, pred, aft_loss_distribution_scale));
const bst_float w = is_null_weight ? 1.0f : _weights[_idx];
_out_gpair[_idx] = GradientPair(grad * w, hess * w);
},
common::Range{0, static_cast<int64_t>(ndata)}, this->ctx_->Threads(), device)
.Eval(out_gpair->Data(), &preds, &info.labels_lower_bound_, &info.labels_upper_bound_,
&info.weights_);
}

void GetGradient(const HostDeviceVector<bst_float>& preds, const MetaInfo& info, int /*iter*/,
Expand All @@ -77,28 +75,28 @@ class AFTObj : public ObjFunction {
const bool is_null_weight = info.weights_.Size() == 0;
if (!is_null_weight) {
CHECK_EQ(info.weights_.Size(), ndata)
<< "Number of weights should be equal to number of data points.";
<< "Number of weights should be equal to number of data points.";
}

switch (param_.aft_loss_distribution) {
case common::ProbabilityDistributionType::kNormal:
GetGradientImpl<common::NormalDistribution>(preds, info, out_gpair, ndata, device,
is_null_weight, aft_loss_distribution_scale);
break;
case common::ProbabilityDistributionType::kLogistic:
GetGradientImpl<common::LogisticDistribution>(preds, info, out_gpair, ndata, device,
case common::ProbabilityDistributionType::kNormal:
GetGradientImpl<common::NormalDistribution>(preds, info, out_gpair, ndata, device,
is_null_weight, aft_loss_distribution_scale);
break;
case common::ProbabilityDistributionType::kExtreme:
GetGradientImpl<common::ExtremeDistribution>(preds, info, out_gpair, ndata, device,
is_null_weight, aft_loss_distribution_scale);
break;
default:
LOG(FATAL) << "Unrecognized distribution";
break;
case common::ProbabilityDistributionType::kLogistic:
GetGradientImpl<common::LogisticDistribution>(preds, info, out_gpair, ndata, device,
is_null_weight, aft_loss_distribution_scale);
break;
case common::ProbabilityDistributionType::kExtreme:
GetGradientImpl<common::ExtremeDistribution>(preds, info, out_gpair, ndata, device,
is_null_weight, aft_loss_distribution_scale);
break;
default:
LOG(FATAL) << "Unrecognized distribution";
}
}

void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
void PredTransform(HostDeviceVector<bst_float>* io_preds) const override {
// Trees give us a prediction in log scale, so exponentiate
common::Transform<>::Init(
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
Expand All @@ -120,19 +118,15 @@ class AFTObj : public ObjFunction {
});
}

const char* DefaultEvalMetric() const override {
return "aft-nloglik";
}
const char* DefaultEvalMetric() const override { return "aft-nloglik"; }

void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["name"] = String("survival:aft");
out["aft_loss_param"] = ToJson(param_);
}

void LoadConfig(Json const& in) override {
FromJson(in["aft_loss_param"], &param_);
}
void LoadConfig(Json const& in) override { FromJson(in["aft_loss_param"], &param_); }
Json DefaultMetricConfig() const override {
Json config{Object{}};
config["name"] = String{this->DefaultEvalMetric()};
Expand All @@ -147,7 +141,7 @@ class AFTObj : public ObjFunction {
// register the objective functions
XGBOOST_REGISTER_OBJECTIVE(AFTObj, "survival:aft")
.describe("AFT loss function")
.set_body([]() { return new AFTObj(); });
.set_body([](Args const& args) { return new AFTObj{args}; });

} // namespace obj
} // namespace xgboost
2 changes: 1 addition & 1 deletion src/objective/hinge.cu
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,6 @@ class HingeObj : public FitIntercept {
// register the objective functions
XGBOOST_REGISTER_OBJECTIVE(HingeObj, "binary:hinge")
.describe("Hinge loss. Expects labels to be in [0,1f]")
.set_body([]() { return new HingeObj(); });
.set_body([](Args const &) { return new HingeObj(); });

} // namespace xgboost::obj
15 changes: 12 additions & 3 deletions src/objective/lambdarank_obj.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,9 @@ class LambdaRankObj : public FitIntercept {
}

public:
explicit LambdaRankObj(Args const& args) { param_.UpdateAllowUnknown(args); }
LambdaRankObj() = default;

void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); }
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
Expand Down Expand Up @@ -327,6 +330,8 @@ class LambdaRankObj : public FitIntercept {

class LambdaRankNDCG : public LambdaRankObj<LambdaRankNDCG, ltr::NDCGCache> {
public:
using LambdaRankObj<LambdaRankNDCG, ltr::NDCGCache>::LambdaRankObj;

template <bool unbiased, bool exp_gain>
void CalcLambdaForGroupNDCG(std::int32_t iter, common::Span<float const> g_predt,
linalg::VectorView<float const> g_label, float w,
Expand Down Expand Up @@ -474,6 +479,8 @@ void MAPStat(Context const* ctx, linalg::VectorView<float const> label,

class LambdaRankMAP : public LambdaRankObj<LambdaRankMAP, ltr::MAPCache> {
public:
using LambdaRankObj<LambdaRankMAP, ltr::MAPCache>::LambdaRankObj;

void GetGradientImpl(std::int32_t iter, const HostDeviceVector<float>& predt,
const MetaInfo& info, linalg::Matrix<GradientPair>* out_gpair) {
if (ctx_->IsCUDA()) {
Expand Down Expand Up @@ -574,6 +581,8 @@ void LambdaRankGetGradientMAP(Context const*, std::int32_t, HostDeviceVector<flo
*/
class LambdaRankPairwise : public LambdaRankObj<LambdaRankPairwise, ltr::RankingCache> {
public:
using LambdaRankObj<LambdaRankPairwise, ltr::RankingCache>::LambdaRankObj;

void GetGradientImpl(std::int32_t iter, const HostDeviceVector<float>& predt,
const MetaInfo& info, linalg::Matrix<GradientPair>* out_gpair) {
if (ctx_->IsCUDA()) {
Expand Down Expand Up @@ -657,15 +666,15 @@ void LambdaRankGetGradientPairwise(Context const*, std::int32_t, HostDeviceVecto

XGBOOST_REGISTER_OBJECTIVE(LambdaRankNDCG, LambdaRankNDCG::Name())
.describe("LambdaRank with NDCG loss as objective")
.set_body([]() { return new LambdaRankNDCG{}; });
.set_body([](Args const& args) { return new LambdaRankNDCG{args}; });

XGBOOST_REGISTER_OBJECTIVE(LambdaRankPairwise, LambdaRankPairwise::Name())
.describe("LambdaRank with RankNet loss as objective")
.set_body([]() { return new LambdaRankPairwise{}; });
.set_body([](Args const& args) { return new LambdaRankPairwise{args}; });

XGBOOST_REGISTER_OBJECTIVE(LambdaRankMAP, LambdaRankMAP::Name())
.describe("LambdaRank with MAP loss as objective.")
.set_body([]() { return new LambdaRankMAP{}; });
.set_body([](Args const& args) { return new LambdaRankMAP{args}; });

DMLC_REGISTRY_FILE_TAG(lambdarank_obj);
} // namespace xgboost::obj
9 changes: 7 additions & 2 deletions src/objective/multiclass_obj.cu
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ void ValidateLabel(Context const* ctx, MetaInfo const& info, std::int64_t n_clas
class SoftmaxMultiClassObj : public ObjFunction {
public:
explicit SoftmaxMultiClassObj(bool output_prob) : output_prob_(output_prob) {}
SoftmaxMultiClassObj(bool output_prob, Args const& args) : output_prob_(output_prob) {
if (!args.empty()) {
param_.UpdateAllowUnknown(args);
}
}

void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); }

Expand Down Expand Up @@ -233,9 +238,9 @@ DMLC_REGISTER_PARAMETER(SoftmaxMultiClassParam);

XGBOOST_REGISTER_OBJECTIVE(SoftmaxMultiClass, "multi:softmax")
.describe("Softmax for multi-class classification, output class index.")
.set_body([]() { return new SoftmaxMultiClassObj(false); });
.set_body([](Args const& args) { return new SoftmaxMultiClassObj(false, args); });

XGBOOST_REGISTER_OBJECTIVE(SoftprobMultiClass, "multi:softprob")
.describe("Softmax for multi-class classification, output probability distribution.")
.set_body([]() { return new SoftmaxMultiClassObj(true); });
.set_body([](Args const& args) { return new SoftmaxMultiClassObj(true, args); });
} // namespace xgboost::obj
9 changes: 4 additions & 5 deletions src/objective/objective.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,17 @@ DMLC_REGISTRY_ENABLE(::xgboost::ObjFunctionReg);

namespace xgboost {
// implement factory functions
ObjFunction* ObjFunction::Create(const std::string& name, Context const* ctx) {
ObjFunction* ObjFunction::Create(const std::string& name, Context const* ctx, Args const& args) {
std::string obj_name = name;
auto *e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(obj_name);
auto* e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(obj_name);
if (e == nullptr) {
std::stringstream ss;
for (const auto& entry : ::dmlc::Registry< ::xgboost::ObjFunctionReg>::List()) {
ss << "Objective candidate: " << entry->name << "\n";
}
LOG(FATAL) << "Unknown objective function: `" << name << "`\n"
<< ss.str();
LOG(FATAL) << "Unknown objective function: `" << name << "`\n" << ss.str();
}
auto pobj = (e->body)();
auto pobj = (e->body)(args);
pobj->ctx_ = ctx;
return pobj;
}
Expand Down
Loading
Loading