Implement the functionality of the hyperparameter alpha_block with test
This commit is contained in:
parent
18844c7da7
commit
3d6b4f0614
@ -92,7 +92,25 @@ namespace bayesnet {
|
|||||||
model->fit(dataset, features, className, states, weights_, smoothing);
|
model->fit(dataset, features, className, states, weights_, smoothing);
|
||||||
alpha_t = 0.0;
|
alpha_t = 0.0;
|
||||||
if (!block_update) {
|
if (!block_update) {
|
||||||
auto ypred = model->predict(X_train);
|
torch::Tensor ypred;
|
||||||
|
if (alpha_block) {
|
||||||
|
//
|
||||||
|
// Compute the prediction with the current ensemble + model
|
||||||
|
//
|
||||||
|
// Add the model to the ensemble
|
||||||
|
n_models++;
|
||||||
|
models.push_back(std::move(model));
|
||||||
|
significanceModels.push_back(1);
|
||||||
|
// Compute the prediction
|
||||||
|
ypred = predict(X_train);
|
||||||
|
// Remove the model from the ensemble
|
||||||
|
model = std::move(models.back());
|
||||||
|
models.pop_back();
|
||||||
|
significanceModels.pop_back();
|
||||||
|
n_models--;
|
||||||
|
} else {
|
||||||
|
ypred = model->predict(X_train);
|
||||||
|
}
|
||||||
// Step 3.1: Compute the classifier amout of say
|
// Step 3.1: Compute the classifier amout of say
|
||||||
std::tie(weights_, alpha_t, finished) = update_weights(y_train, ypred, weights_);
|
std::tie(weights_, alpha_t, finished) = update_weights(y_train, ypred, weights_);
|
||||||
}
|
}
|
||||||
|
@ -130,6 +130,8 @@ TEST_CASE("Oddities", "[BoostAODE]")
|
|||||||
{ { "select_features","IWSS" }, { "threshold", 0.51 } },
|
{ { "select_features","IWSS" }, { "threshold", 0.51 } },
|
||||||
{ { "select_features","FCBF" }, { "threshold", 1e-8 } },
|
{ { "select_features","FCBF" }, { "threshold", 1e-8 } },
|
||||||
{ { "select_features","FCBF" }, { "threshold", 1.01 } },
|
{ { "select_features","FCBF" }, { "threshold", 1.01 } },
|
||||||
|
{ { "alpha_block", true }, { "block_update", true } },
|
||||||
|
{ { "bisection", false }, { "block_update", true } },
|
||||||
};
|
};
|
||||||
for (const auto& hyper : bad_hyper_fit.items()) {
|
for (const auto& hyper : bad_hyper_fit.items()) {
|
||||||
INFO("BoostAODE hyper: " << hyper.value().dump());
|
INFO("BoostAODE hyper: " << hyper.value().dump());
|
||||||
@ -137,7 +139,6 @@ TEST_CASE("Oddities", "[BoostAODE]")
|
|||||||
REQUIRE_THROWS_AS(clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing), std::invalid_argument);
|
REQUIRE_THROWS_AS(clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing), std::invalid_argument);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("Bisection Best", "[BoostAODE]")
|
TEST_CASE("Bisection Best", "[BoostAODE]")
|
||||||
{
|
{
|
||||||
auto clf = bayesnet::BoostAODE();
|
auto clf = bayesnet::BoostAODE();
|
||||||
@ -180,7 +181,6 @@ TEST_CASE("Bisection Best vs Last", "[BoostAODE]")
|
|||||||
auto score_last = clf.score(raw.X_test, raw.y_test);
|
auto score_last = clf.score(raw.X_test, raw.y_test);
|
||||||
REQUIRE(score_last == Catch::Approx(0.976666689f).epsilon(raw.epsilon));
|
REQUIRE(score_last == Catch::Approx(0.976666689f).epsilon(raw.epsilon));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("Block Update", "[BoostAODE]")
|
TEST_CASE("Block Update", "[BoostAODE]")
|
||||||
{
|
{
|
||||||
auto clf = bayesnet::BoostAODE();
|
auto clf = bayesnet::BoostAODE();
|
||||||
@ -210,4 +210,19 @@ TEST_CASE("Block Update", "[BoostAODE]")
|
|||||||
// std::cout << note << std::endl;
|
// std::cout << note << std::endl;
|
||||||
// }
|
// }
|
||||||
// std::cout << "Score " << score << std::endl;
|
// std::cout << "Score " << score << std::endl;
|
||||||
|
}
|
||||||
|
TEST_CASE("Alphablock", "[BoostAODE]")
|
||||||
|
{
|
||||||
|
auto clf_alpha = bayesnet::BoostAODE();
|
||||||
|
auto clf_no_alpha = bayesnet::BoostAODE();
|
||||||
|
auto raw = RawDatasets("diabetes", true);
|
||||||
|
clf_alpha.setHyperparameters({
|
||||||
|
{"alpha_block", true},
|
||||||
|
});
|
||||||
|
clf_alpha.fit(raw.X_train, raw.y_train, raw.features, raw.className, raw.states, raw.smoothing);
|
||||||
|
clf_no_alpha.fit(raw.X_train, raw.y_train, raw.features, raw.className, raw.states, raw.smoothing);
|
||||||
|
auto score_alpha = clf_alpha.score(raw.X_test, raw.y_test);
|
||||||
|
auto score_no_alpha = clf_no_alpha.score(raw.X_test, raw.y_test);
|
||||||
|
REQUIRE(score_alpha == Catch::Approx(0.720779f).epsilon(raw.epsilon));
|
||||||
|
REQUIRE(score_no_alpha == Catch::Approx(0.733766f).epsilon(raw.epsilon));
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user