Add tests for Ld models predict_proba

This commit is contained in:
2025-06-01 14:55:31 +02:00
parent ad72bb355b
commit ab86dae90d
2 changed files with 79 additions and 30 deletions

View File

@@ -152,7 +152,7 @@ TEST_CASE("Get num features & num edges", "[Models]")
TEST_CASE("Model predict_proba", "[Models]")
{
std::string model = GENERATE("TAN", "SPODE", "BoostAODEproba", "BoostAODEvoting");
std::string model = GENERATE("TAN", "SPODE", "BoostAODEproba", "BoostAODEvoting", "TANLd", "SPODELd", "KDBLd");
auto res_prob_tan = std::vector<std::vector<double>>({ {0.00375671, 0.994457, 0.00178621},
{0.00137462, 0.992734, 0.00589123},
{0.00137462, 0.992734, 0.00589123},
@@ -180,50 +180,99 @@ TEST_CASE("Model predict_proba", "[Models]")
{0.0284828, 0.770524, 0.200993},
{0.0213182, 0.857189, 0.121493},
{0.00868436, 0.949494, 0.0418215} });
auto res_prob_tanld = std::vector<std::vector<double>>({ {0.000544493, 0.995796, 0.00365992 },
{0.000908092, 0.997268, 0.00182429 },
{0.000908092, 0.997268, 0.00182429 },
{0.000908092, 0.997268, 0.00182429 },
{0.00228423, 0.994645, 0.00307078 },
{0.00120539, 0.0666788, 0.932116 },
{0.00361847, 0.979203, 0.017179 },
{0.00483293, 0.985326, 0.00984064 },
{0.000595606, 0.9977, 0.00170441 } });
auto res_prob_spodeld = std::vector<std::vector<double>>({ {0.000908024, 0.993742, 0.00535024 },
{0.00187726, 0.99167, 0.00645308 },
{0.00187726, 0.99167, 0.00645308 },
{0.00187726, 0.99167, 0.00645308 },
{0.00287539, 0.993736, 0.00338846 },
{0.00294402, 0.268495, 0.728561 },
{0.0132381, 0.873282, 0.113479 },
{0.0159412, 0.969228, 0.0148308 },
{0.00203487, 0.989762, 0.00820356 } });
auto res_prob_kdbld = std::vector<std::vector<double>>({ {0.000738981, 0.997208, 0.00205272 },
{0.00087708, 0.996687, 0.00243633 },
{0.00087708, 0.996687, 0.00243633 },
{0.00087708, 0.996687, 0.00243633 },
{0.000738981, 0.997208, 0.00205272 },
{0.00512442, 0.0455504, 0.949325 },
{0.0023632, 0.976631, 0.0210063 },
{0.00189194, 0.992853, 0.00525538 },
{0.00189194, 0.992853, 0.00525538, } });
auto res_prob_voting = std::vector<std::vector<double>>(
{ {0, 1, 0}, {0, 1, 0}, {0, 1, 0}, {0, 1, 0}, {0, 1, 0}, {0, 0, 1}, {0, 1, 0}, {0, 1, 0}, {0, 1, 0} });
std::map<std::string, std::vector<std::vector<double>>> res_prob{ {"TAN", res_prob_tan},
{"SPODE", res_prob_spode},
{"BoostAODEproba", res_prob_baode},
{"BoostAODEvoting", res_prob_voting} };
{"BoostAODEvoting", res_prob_voting},
{"TANLd", res_prob_tanld},
{"SPODELd", res_prob_spodeld},
{"KDBLd", res_prob_kdbld} };
std::map<std::string, bayesnet::BaseClassifier*> models{ {"TAN", new bayesnet::TAN()},
{"SPODE", new bayesnet::SPODE(0)},
{"BoostAODEproba", new bayesnet::BoostAODE(false)},
{"BoostAODEvoting", new bayesnet::BoostAODE(true)} };
{"BoostAODEvoting", new bayesnet::BoostAODE(true)},
{"TANLd", new bayesnet::TANLd()},
{"SPODELd", new bayesnet::SPODELd(0)},
{"KDBLd", new bayesnet::KDBLd(2)} };
int init_index = 78;
auto raw = RawDatasets("iris", true);
SECTION("Test " + model + " predict_proba")
{
auto ld_model = model.substr(model.length() - 2) == "Ld";
auto discretize = !ld_model;
auto raw = RawDatasets("iris", discretize);
auto clf = models[model];
clf->fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing);
auto y_pred_proba = clf->predict_proba(raw.Xv);
clf->fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states, raw.smoothing);
auto yt_pred_proba = clf->predict_proba(raw.Xt);
auto y_pred = clf->predict(raw.Xv);
auto yt_pred = clf->predict(raw.Xt);
REQUIRE(y_pred.size() == yt_pred.size(0));
REQUIRE(y_pred.size() == y_pred_proba.size());
REQUIRE(y_pred.size() == yt_pred_proba.size(0));
REQUIRE(y_pred.size() == raw.yv.size());
REQUIRE(y_pred_proba[0].size() == 3);
REQUIRE(yt_pred_proba.size(1) == y_pred_proba[0].size());
for (int i = 0; i < 9; ++i) {
auto maxElem = max_element(y_pred_proba[i].begin(), y_pred_proba[i].end());
int predictedClass = distance(y_pred_proba[i].begin(), maxElem);
REQUIRE(predictedClass == y_pred[i]);
// Check predict is coherent with predict_proba
REQUIRE(yt_pred_proba[i].argmax().item<int>() == y_pred[i]);
for (int j = 0; j < yt_pred_proba.size(1); j++) {
REQUIRE(yt_pred_proba[i][j].item<double>() == Catch::Approx(y_pred_proba[i][j]).epsilon(raw.epsilon));
std::vector<int> y_pred;
std::vector<std::vector<double>> y_pred_proba;
if (!ld_model) {
y_pred = clf->predict(raw.Xv);
y_pred_proba = clf->predict_proba(raw.Xv);
REQUIRE(y_pred.size() == y_pred_proba.size());
REQUIRE(y_pred.size() == yt_pred.size(0));
REQUIRE(y_pred.size() == yt_pred_proba.size(0));
REQUIRE(y_pred_proba[0].size() == 3);
REQUIRE(y_pred.size() == raw.yv.size());
REQUIRE(yt_pred_proba.size(1) == y_pred_proba[0].size());
for (int i = 0; i < 9; ++i) {
auto maxElem = max_element(y_pred_proba[i].begin(), y_pred_proba[i].end());
int predictedClass = distance(y_pred_proba[i].begin(), maxElem);
REQUIRE(predictedClass == y_pred[i]);
// Check predict is coherent with predict_proba
REQUIRE(yt_pred_proba[i].argmax().item<int>() == y_pred[i]);
for (int j = 0; j < yt_pred_proba.size(1); j++) {
REQUIRE(yt_pred_proba[i][j].item<double>() == Catch::Approx(y_pred_proba[i][j]).epsilon(raw.epsilon));
}
}
}
// Check predict_proba values for vectors and tensors
for (int i = 0; i < 9; i++) {
REQUIRE(y_pred[i] == yt_pred[i].item<int>());
for (int j = 0; j < 3; j++) {
REQUIRE(res_prob[model][i][j] == Catch::Approx(y_pred_proba[i + init_index][j]).epsilon(raw.epsilon));
REQUIRE(res_prob[model][i][j] ==
Catch::Approx(yt_pred_proba[i + init_index][j].item<double>()).epsilon(raw.epsilon));
// Check predict_proba values for vectors and tensors
for (int i = 0; i < 9; i++) {
REQUIRE(y_pred[i] == yt_pred[i].item<int>());
for (int j = 0; j < 3; j++) {
REQUIRE(res_prob[model][i][j] == Catch::Approx(y_pred_proba[i + init_index][j]).epsilon(raw.epsilon));
REQUIRE(res_prob[model][i][j] ==
Catch::Approx(yt_pred_proba[i + init_index][j].item<double>()).epsilon(raw.epsilon));
}
}
} else {
// Check predict_proba values for vectors and tensors
auto predictedClasses = yt_pred_proba.argmax(1);
for (int i = 0; i < 9; i++) {
REQUIRE(predictedClasses[i].item<int>() == yt_pred[i].item<int>());
for (int j = 0; j < 3; j++) {
REQUIRE(res_prob[model][i][j] ==
Catch::Approx(yt_pred_proba[i + init_index][j].item<double>()).epsilon(raw.epsilon));
}
}
}
delete clf;