block_update #26

Merged
rmontanana merged 8 commits from block_update into main 2024-04-15 10:26:51 +00:00
3 changed files with 24 additions and 16 deletions
Showing only changes of commit 1326891d6a - Show all commits

View File

@ -33,6 +33,11 @@ TEST_CASE("Show", "[Ensemble]")
{ {
auto clf = bayesnet::BoostAODE(); auto clf = bayesnet::BoostAODE();
auto raw = RawDatasets("iris", true); auto raw = RawDatasets("iris", true);
clf.setHyperparameters({
{"bisection", false},
{"maxTolerance", 1},
{"convergence", false},
});
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv); clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
std::vector<std::string> expected = { std::vector<std::string> expected = {
"class -> sepallength, sepalwidth, petallength, petalwidth, ", "class -> sepallength, sepalwidth, petallength, petalwidth, ",

View File

@ -115,15 +115,15 @@ TEST_CASE("Model predict_proba", "[Models]")
{0.003135, 0.991799, 0.0050661} {0.003135, 0.991799, 0.0050661}
}); });
auto res_prob_baode = std::vector<std::vector<double>>({ auto res_prob_baode = std::vector<std::vector<double>>({
{0.00803291, 0.9676, 0.0243672}, {0.0112349, 0.962274, 0.0264907},
{0.00398714, 0.945126, 0.050887}, {0.00371025, 0.950592, 0.0456973},
{0.00398714, 0.945126, 0.050887}, {0.00371025, 0.950592, 0.0456973},
{0.00398714, 0.945126, 0.050887}, {0.00371025, 0.950592, 0.0456973},
{0.00189227, 0.859575, 0.138533}, {0.00369275, 0.84967, 0.146637},
{0.0118341, 0.442149, 0.546017}, {0.0252205, 0.113564, 0.861215},
{0.0216135, 0.785781, 0.192605}, {0.0284828, 0.770524, 0.200993},
{0.0204803, 0.844276, 0.135244}, {0.0213182, 0.857189, 0.121493},
{0.00576313, 0.961665, 0.0325716}, {0.00868436, 0.949494, 0.0418215}
}); });
auto res_prob_voting = std::vector<std::vector<double>>({ auto res_prob_voting = std::vector<std::vector<double>>({
{0, 1, 0}, {0, 1, 0},
@ -131,8 +131,8 @@ TEST_CASE("Model predict_proba", "[Models]")
{0, 1, 0}, {0, 1, 0},
{0, 1, 0}, {0, 1, 0},
{0, 1, 0}, {0, 1, 0},
{0, 0.447909, 0.552091}, {0, 0, 1},
{0, 0.811482, 0.188517}, {0, 1, 0},
{0, 1, 0}, {0, 1, 0},
{0, 1, 0} {0, 1, 0}
}); });
@ -155,7 +155,7 @@ TEST_CASE("Model predict_proba", "[Models]")
REQUIRE(y_pred.size() == raw.yv.size()); REQUIRE(y_pred.size() == raw.yv.size());
REQUIRE(y_pred_proba[0].size() == 3); REQUIRE(y_pred_proba[0].size() == 3);
REQUIRE(yt_pred_proba.size(1) == y_pred_proba[0].size()); REQUIRE(yt_pred_proba.size(1) == y_pred_proba[0].size());
for (int i = 0; i < y_pred_proba.size(); ++i) { for (int i = 0; i < 9; ++i) {
auto maxElem = max_element(y_pred_proba[i].begin(), y_pred_proba[i].end()); auto maxElem = max_element(y_pred_proba[i].begin(), y_pred_proba[i].end());
int predictedClass = distance(y_pred_proba[i].begin(), maxElem); int predictedClass = distance(y_pred_proba[i].begin(), maxElem);
REQUIRE(predictedClass == y_pred[i]); REQUIRE(predictedClass == y_pred[i]);
@ -166,7 +166,7 @@ TEST_CASE("Model predict_proba", "[Models]")
} }
} }
// Check predict_proba values for vectors and tensors // Check predict_proba values for vectors and tensors
for (int i = 0; i < res_prob.size(); i++) { for (int i = 0; i < 9; i++) {
REQUIRE(y_pred[i] == yt_pred[i].item<int>()); REQUIRE(y_pred[i] == yt_pred[i].item<int>());
for (int j = 0; j < 3; j++) { for (int j = 0; j < 3; j++) {
REQUIRE(res_prob[model][i][j] == Catch::Approx(y_pred_proba[i + init_index][j]).epsilon(raw.epsilon)); REQUIRE(res_prob[model][i][j] == Catch::Approx(y_pred_proba[i + init_index][j]).epsilon(raw.epsilon));

View File

@ -27,7 +27,7 @@ TEST_CASE("Feature_select IWSS", "[BoostAODE]")
REQUIRE(clf.getNumberOfNodes() == 90); REQUIRE(clf.getNumberOfNodes() == 90);
REQUIRE(clf.getNumberOfEdges() == 153); REQUIRE(clf.getNumberOfEdges() == 153);
REQUIRE(clf.getNotes().size() == 2); REQUIRE(clf.getNotes().size() == 2);
REQUIRE(clf.getNotes()[0] == "Used features in initialization: 5 of 9 with IWSS"); REQUIRE(clf.getNotes()[0] == "Used features in initialization: 4 of 9 with IWSS");
REQUIRE(clf.getNotes()[1] == "Number of models: 9"); REQUIRE(clf.getNotes()[1] == "Number of models: 9");
} }
TEST_CASE("Feature_select FCBF", "[BoostAODE]") TEST_CASE("Feature_select FCBF", "[BoostAODE]")
@ -76,8 +76,8 @@ TEST_CASE("Voting vs proba", "[BoostAODE]")
auto pred_voting = clf.predict_proba(raw.Xv); auto pred_voting = clf.predict_proba(raw.Xv);
REQUIRE(score_proba == Catch::Approx(0.97333).epsilon(raw.epsilon)); REQUIRE(score_proba == Catch::Approx(0.97333).epsilon(raw.epsilon));
REQUIRE(score_voting == Catch::Approx(0.98).epsilon(raw.epsilon)); REQUIRE(score_voting == Catch::Approx(0.98).epsilon(raw.epsilon));
REQUIRE(pred_voting[83][2] == Catch::Approx(0.552091).epsilon(raw.epsilon)); REQUIRE(pred_voting[83][2] == Catch::Approx(1.0).epsilon(raw.epsilon));
REQUIRE(pred_proba[83][2] == Catch::Approx(0.546017).epsilon(raw.epsilon)); REQUIRE(pred_proba[83][2] == Catch::Approx(0.86121525).epsilon(raw.epsilon));
REQUIRE(clf.dump_cpt() == ""); REQUIRE(clf.dump_cpt() == "");
REQUIRE(clf.topological_order() == std::vector<std::string>()); REQUIRE(clf.topological_order() == std::vector<std::string>());
} }
@ -91,6 +91,9 @@ TEST_CASE("Order asc, desc & random", "[BoostAODE]")
auto clf = bayesnet::BoostAODE(); auto clf = bayesnet::BoostAODE();
clf.setHyperparameters({ clf.setHyperparameters({
{"order", order}, {"order", order},
{"bisection", false},
{"maxTolerance", 1},
{"convergence", false},
}); });
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv); clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
auto score = clf.score(raw.Xv, raw.yv); auto score = clf.score(raw.Xv, raw.yv);