Implement the functionality of the hyperparameter alpha_block with test
This commit is contained in:
parent
18844c7da7
commit
3d6b4f0614
@ -92,7 +92,25 @@ namespace bayesnet {
|
||||
model->fit(dataset, features, className, states, weights_, smoothing);
|
||||
alpha_t = 0.0;
|
||||
if (!block_update) {
|
||||
auto ypred = model->predict(X_train);
|
||||
torch::Tensor ypred;
|
||||
if (alpha_block) {
|
||||
//
|
||||
// Compute the prediction with the current ensemble + model
|
||||
//
|
||||
// Add the model to the ensemble
|
||||
n_models++;
|
||||
models.push_back(std::move(model));
|
||||
significanceModels.push_back(1);
|
||||
// Compute the prediction
|
||||
ypred = predict(X_train);
|
||||
// Remove the model from the ensemble
|
||||
model = std::move(models.back());
|
||||
models.pop_back();
|
||||
significanceModels.pop_back();
|
||||
n_models--;
|
||||
} else {
|
||||
ypred = model->predict(X_train);
|
||||
}
|
||||
// Step 3.1: Compute the classifier amout of say
|
||||
std::tie(weights_, alpha_t, finished) = update_weights(y_train, ypred, weights_);
|
||||
}
|
||||
|
@ -130,6 +130,8 @@ TEST_CASE("Oddities", "[BoostAODE]")
|
||||
{ { "select_features","IWSS" }, { "threshold", 0.51 } },
|
||||
{ { "select_features","FCBF" }, { "threshold", 1e-8 } },
|
||||
{ { "select_features","FCBF" }, { "threshold", 1.01 } },
|
||||
{ { "alpha_block", true }, { "block_update", true } },
|
||||
{ { "bisection", false }, { "block_update", true } },
|
||||
};
|
||||
for (const auto& hyper : bad_hyper_fit.items()) {
|
||||
INFO("BoostAODE hyper: " << hyper.value().dump());
|
||||
@ -137,7 +139,6 @@ TEST_CASE("Oddities", "[BoostAODE]")
|
||||
REQUIRE_THROWS_AS(clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing), std::invalid_argument);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("Bisection Best", "[BoostAODE]")
|
||||
{
|
||||
auto clf = bayesnet::BoostAODE();
|
||||
@ -180,7 +181,6 @@ TEST_CASE("Bisection Best vs Last", "[BoostAODE]")
|
||||
auto score_last = clf.score(raw.X_test, raw.y_test);
|
||||
REQUIRE(score_last == Catch::Approx(0.976666689f).epsilon(raw.epsilon));
|
||||
}
|
||||
|
||||
TEST_CASE("Block Update", "[BoostAODE]")
|
||||
{
|
||||
auto clf = bayesnet::BoostAODE();
|
||||
@ -210,4 +210,19 @@ TEST_CASE("Block Update", "[BoostAODE]")
|
||||
// std::cout << note << std::endl;
|
||||
// }
|
||||
// std::cout << "Score " << score << std::endl;
|
||||
}
|
||||
TEST_CASE("Alphablock", "[BoostAODE]")
|
||||
{
|
||||
auto clf_alpha = bayesnet::BoostAODE();
|
||||
auto clf_no_alpha = bayesnet::BoostAODE();
|
||||
auto raw = RawDatasets("diabetes", true);
|
||||
clf_alpha.setHyperparameters({
|
||||
{"alpha_block", true},
|
||||
});
|
||||
clf_alpha.fit(raw.X_train, raw.y_train, raw.features, raw.className, raw.states, raw.smoothing);
|
||||
clf_no_alpha.fit(raw.X_train, raw.y_train, raw.features, raw.className, raw.states, raw.smoothing);
|
||||
auto score_alpha = clf_alpha.score(raw.X_test, raw.y_test);
|
||||
auto score_no_alpha = clf_no_alpha.score(raw.X_test, raw.y_test);
|
||||
REQUIRE(score_alpha == Catch::Approx(0.720779f).epsilon(raw.epsilon));
|
||||
REQUIRE(score_no_alpha == Catch::Approx(0.733766f).epsilon(raw.epsilon));
|
||||
}
|
Loading…
Reference in New Issue
Block a user