From ab86dae90d369c053195f235a6118c9e03ea87dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Sun, 1 Jun 2025 14:55:31 +0200 Subject: [PATCH] Add tests for Ld models predict_proba --- README.md | 2 +- tests/TestBayesModels.cc | 107 ++++++++++++++++++++++++++++----------- 2 files changed, 79 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index e7372ab..0115ec4 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ [![Reliability Rating](https://sonarcloud.io/api/project_badges/measure?project=rmontanana_BayesNet&metric=reliability_rating)](https://sonarcloud.io/summary/new_code?id=rmontanana_BayesNet) [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/Doctorado-ML/BayesNet) ![Gitea Last Commit](https://img.shields.io/gitea/last-commit/rmontanana/bayesnet?gitea_url=https://gitea.rmontanana.es&logo=gitea) -[![Coverage Badge](https://img.shields.io/badge/Coverage-99,1%25-green)](https://gitea.rmontanana.es/rmontanana/BayesNet) +[![Coverage Badge](https://img.shields.io/badge/Coverage-99,2%25-green)](https://gitea.rmontanana.es/rmontanana/BayesNet) [![DOI](https://zenodo.org/badge/667782806.svg)](https://doi.org/10.5281/zenodo.14210344) Bayesian Network Classifiers library diff --git a/tests/TestBayesModels.cc b/tests/TestBayesModels.cc index 7a80cb8..450a5cb 100644 --- a/tests/TestBayesModels.cc +++ b/tests/TestBayesModels.cc @@ -152,7 +152,7 @@ TEST_CASE("Get num features & num edges", "[Models]") TEST_CASE("Model predict_proba", "[Models]") { - std::string model = GENERATE("TAN", "SPODE", "BoostAODEproba", "BoostAODEvoting"); + std::string model = GENERATE("TAN", "SPODE", "BoostAODEproba", "BoostAODEvoting", "TANLd", "SPODELd", "KDBLd"); auto res_prob_tan = std::vector>({ {0.00375671, 0.994457, 0.00178621}, {0.00137462, 0.992734, 0.00589123}, {0.00137462, 0.992734, 0.00589123}, @@ -180,50 +180,99 @@ TEST_CASE("Model predict_proba", "[Models]") {0.0284828, 0.770524, 0.200993}, {0.0213182, 0.857189, 0.121493}, {0.00868436, 0.949494, 0.0418215} }); + auto res_prob_tanld = std::vector>({ {0.000544493, 0.995796, 0.00365992 }, + {0.000908092, 0.997268, 0.00182429 }, + {0.000908092, 0.997268, 0.00182429 }, + {0.000908092, 0.997268, 0.00182429 }, + {0.00228423, 0.994645, 0.00307078 }, + {0.00120539, 0.0666788, 0.932116 }, + {0.00361847, 0.979203, 0.017179 }, + {0.00483293, 0.985326, 0.00984064 }, + {0.000595606, 0.9977, 0.00170441 } }); + auto res_prob_spodeld = std::vector>({ {0.000908024, 0.993742, 0.00535024 }, + {0.00187726, 0.99167, 0.00645308 }, + {0.00187726, 0.99167, 0.00645308 }, + {0.00187726, 0.99167, 0.00645308 }, + {0.00287539, 0.993736, 0.00338846 }, + {0.00294402, 0.268495, 0.728561 }, + {0.0132381, 0.873282, 0.113479 }, + {0.0159412, 0.969228, 0.0148308 }, + {0.00203487, 0.989762, 0.00820356 } }); + auto res_prob_kdbld = std::vector>({ {0.000738981, 0.997208, 0.00205272 }, + {0.00087708, 0.996687, 0.00243633 }, + {0.00087708, 0.996687, 0.00243633 }, + {0.00087708, 0.996687, 0.00243633 }, + {0.000738981, 0.997208, 0.00205272 }, + {0.00512442, 0.0455504, 0.949325 }, + {0.0023632, 0.976631, 0.0210063 }, + {0.00189194, 0.992853, 0.00525538 }, + {0.00189194, 0.992853, 0.00525538, } }); auto res_prob_voting = std::vector>( { {0, 1, 0}, {0, 1, 0}, {0, 1, 0}, {0, 1, 0}, {0, 1, 0}, {0, 0, 1}, {0, 1, 0}, {0, 1, 0}, {0, 1, 0} }); std::map>> res_prob{ {"TAN", res_prob_tan}, {"SPODE", res_prob_spode}, {"BoostAODEproba", res_prob_baode}, - {"BoostAODEvoting", res_prob_voting} }; + {"BoostAODEvoting", res_prob_voting}, + {"TANLd", res_prob_tanld}, + {"SPODELd", res_prob_spodeld}, + {"KDBLd", res_prob_kdbld} }; std::map models{ {"TAN", new bayesnet::TAN()}, {"SPODE", new bayesnet::SPODE(0)}, {"BoostAODEproba", new bayesnet::BoostAODE(false)}, - {"BoostAODEvoting", new bayesnet::BoostAODE(true)} }; + {"BoostAODEvoting", new bayesnet::BoostAODE(true)}, + {"TANLd", new bayesnet::TANLd()}, + {"SPODELd", new bayesnet::SPODELd(0)}, + {"KDBLd", new bayesnet::KDBLd(2)} }; int init_index = 78; - auto raw = RawDatasets("iris", true); SECTION("Test " + model + " predict_proba") { + auto ld_model = model.substr(model.length() - 2) == "Ld"; + auto discretize = !ld_model; + auto raw = RawDatasets("iris", discretize); auto clf = models[model]; - clf->fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing); - auto y_pred_proba = clf->predict_proba(raw.Xv); + clf->fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states, raw.smoothing); auto yt_pred_proba = clf->predict_proba(raw.Xt); - auto y_pred = clf->predict(raw.Xv); auto yt_pred = clf->predict(raw.Xt); - REQUIRE(y_pred.size() == yt_pred.size(0)); - REQUIRE(y_pred.size() == y_pred_proba.size()); - REQUIRE(y_pred.size() == yt_pred_proba.size(0)); - REQUIRE(y_pred.size() == raw.yv.size()); - REQUIRE(y_pred_proba[0].size() == 3); - REQUIRE(yt_pred_proba.size(1) == y_pred_proba[0].size()); - for (int i = 0; i < 9; ++i) { - auto maxElem = max_element(y_pred_proba[i].begin(), y_pred_proba[i].end()); - int predictedClass = distance(y_pred_proba[i].begin(), maxElem); - REQUIRE(predictedClass == y_pred[i]); - // Check predict is coherent with predict_proba - REQUIRE(yt_pred_proba[i].argmax().item() == y_pred[i]); - for (int j = 0; j < yt_pred_proba.size(1); j++) { - REQUIRE(yt_pred_proba[i][j].item() == Catch::Approx(y_pred_proba[i][j]).epsilon(raw.epsilon)); + std::vector y_pred; + std::vector> y_pred_proba; + if (!ld_model) { + y_pred = clf->predict(raw.Xv); + y_pred_proba = clf->predict_proba(raw.Xv); + REQUIRE(y_pred.size() == y_pred_proba.size()); + REQUIRE(y_pred.size() == yt_pred.size(0)); + REQUIRE(y_pred.size() == yt_pred_proba.size(0)); + REQUIRE(y_pred_proba[0].size() == 3); + REQUIRE(y_pred.size() == raw.yv.size()); + REQUIRE(yt_pred_proba.size(1) == y_pred_proba[0].size()); + for (int i = 0; i < 9; ++i) { + auto maxElem = max_element(y_pred_proba[i].begin(), y_pred_proba[i].end()); + int predictedClass = distance(y_pred_proba[i].begin(), maxElem); + REQUIRE(predictedClass == y_pred[i]); + // Check predict is coherent with predict_proba + REQUIRE(yt_pred_proba[i].argmax().item() == y_pred[i]); + for (int j = 0; j < yt_pred_proba.size(1); j++) { + REQUIRE(yt_pred_proba[i][j].item() == Catch::Approx(y_pred_proba[i][j]).epsilon(raw.epsilon)); + } } - } - // Check predict_proba values for vectors and tensors - for (int i = 0; i < 9; i++) { - REQUIRE(y_pred[i] == yt_pred[i].item()); - for (int j = 0; j < 3; j++) { - REQUIRE(res_prob[model][i][j] == Catch::Approx(y_pred_proba[i + init_index][j]).epsilon(raw.epsilon)); - REQUIRE(res_prob[model][i][j] == - Catch::Approx(yt_pred_proba[i + init_index][j].item()).epsilon(raw.epsilon)); + // Check predict_proba values for vectors and tensors + for (int i = 0; i < 9; i++) { + REQUIRE(y_pred[i] == yt_pred[i].item()); + for (int j = 0; j < 3; j++) { + REQUIRE(res_prob[model][i][j] == Catch::Approx(y_pred_proba[i + init_index][j]).epsilon(raw.epsilon)); + REQUIRE(res_prob[model][i][j] == + Catch::Approx(yt_pred_proba[i + init_index][j].item()).epsilon(raw.epsilon)); + } + } + } else { + // Check predict_proba values for vectors and tensors + auto predictedClasses = yt_pred_proba.argmax(1); + for (int i = 0; i < 9; i++) { + REQUIRE(predictedClasses[i].item() == yt_pred[i].item()); + for (int j = 0; j < 3; j++) { + REQUIRE(res_prob[model][i][j] == + Catch::Approx(yt_pred_proba[i + init_index][j].item()).epsilon(raw.epsilon)); + } } } delete clf;