Make some boostAODE tests

This commit is contained in:
Ricardo Montañana Gómez 2024-04-08 22:30:55 +02:00
parent a1178554ff
commit fbbed8ad68
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
5 changed files with 66 additions and 1810 deletions

View File

@ -5,7 +5,7 @@
![Gitea Release](https://img.shields.io/gitea/v/release/rmontanana/bayesnet?gitea_url=https://gitea.rmontanana.es:3000)
[![Codacy Badge](https://app.codacy.com/project/badge/Grade/cf3e0ac71d764650b1bf4d8d00d303b1)](https://app.codacy.com/gh/Doctorado-ML/BayesNet/dashboard?utm_source=gh&utm_medium=referral&utm_content=&utm_campaign=Badge_grade)
![Gitea Last Commit](https://img.shields.io/gitea/last-commit/rmontanana/bayesnet?gitea_url=https://gitea.rmontanana.es:3000&logo=gitea)
![Static Badge](https://img.shields.io/badge/Coverage-94,0%25-green)
![Static Badge](https://img.shields.io/badge/Coverage-95,8%25-green)
Bayesian Network Classifiers using libtorch from scratch

View File

@ -3,6 +3,8 @@
#include <string>
#include "TestUtils.h"
#include "bayesnet/classifiers/TAN.h"
#include "bayesnet/classifiers/KDB.h"
#include "bayesnet/classifiers/KDBLd.h"
TEST_CASE("Test Cannot build dataset with wrong data vector", "[Classifier]")
@ -83,4 +85,20 @@ TEST_CASE("Not fitted model", "[Classifier]")
REQUIRE_THROWS_WITH(model.predict_proba(raw.Xv), message);
REQUIRE_THROWS_AS(model.score(raw.Xv, raw.yv), std::logic_error);
REQUIRE_THROWS_WITH(model.score(raw.Xv, raw.yv), message);
}
TEST_CASE("KDB Graph", "[Classifier]")
{
auto model = bayesnet::KDB(2);
auto raw = RawDatasets("iris", true);
model.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
auto graph = model.graph();
REQUIRE(graph.size() == 15);
}
TEST_CASE("KDBLd Graph", "[Classifier]")
{
auto model = bayesnet::KDBLd(2);
auto raw = RawDatasets("iris", false);
model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
auto graph = model.graph();
REQUIRE(graph.size() == 15);
}

View File

@ -102,5 +102,50 @@ TEST_CASE("Order asc, desc & random", "[BoostAODE]")
}
TEST_CASE("Oddities", "[BoostAODE]")
{
auto clf = bayesnet::BoostAODE();
auto raw = RawDatasets("iris", true);
auto bad_hyper = nlohmann::json{
{ { "order", "duck" } },
{ { "select_features", "duck" } },
{ { "maxTolerance", 0 } },
{ { "maxTolerance", 5 } },
};
for (const auto& hyper : bad_hyper.items()) {
INFO("BoostAODE hyper: " + hyper.value().dump());
REQUIRE_THROWS_AS(clf.setHyperparameters(hyper.value()), std::invalid_argument);
}
REQUIRE_THROWS_AS(clf.setHyperparameters({ {"maxTolerance", 0 } }), std::invalid_argument);
auto bad_hyper_fit = nlohmann::json{
{ { "select_features","IWSS" }, { "threshold", -0.01 } },
{ { "select_features","IWSS" }, { "threshold", 0.51 } },
{ { "select_features","FCBF" }, { "threshold", 1e-8 } },
{ { "select_features","FCBF" }, { "threshold", 1.01 } },
};
for (const auto& hyper : bad_hyper_fit.items()) {
INFO("BoostAODE hyper: " + hyper.value().dump());
clf.setHyperparameters(hyper.value());
REQUIRE_THROWS_AS(clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv), std::invalid_argument);
}
}
TEST_CASE("Bisection", "[BoostAODE]")
{
auto clf = bayesnet::BoostAODE();
auto raw = RawDatasets("mfeat-factors", true);
clf.setHyperparameters({
{"bisection", true},
{"maxTolerance", 3},
{"convergence", true},
});
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
REQUIRE(clf.getNumberOfNodes() == 217);
REQUIRE(clf.getNumberOfEdges() == 431);
REQUIRE(clf.getNotes().size() == 3);
REQUIRE(clf.getNotes()[0] == "Convergence threshold reached & 15 models eliminated");
REQUIRE(clf.getNotes()[1] == "Used features in train: 16 of 216");
REQUIRE(clf.getNotes()[2] == "Number of models: 1");
auto score = clf.score(raw.Xv, raw.yv);
auto scoret = clf.score(raw.Xt, raw.yt);
REQUIRE(score == Catch::Approx(1.0f).epsilon(raw.epsilon));
REQUIRE(scoret == Catch::Approx(1.0f).epsilon(raw.epsilon));
}

File diff suppressed because it is too large Load Diff

View File

@ -12,7 +12,7 @@ output = subprocess.check_output(
)
value = float(output.decode("utf-8").strip().replace("%", ""))
if value < 90:
print("Coverage is less than 90%. I won't update the badge.")
print("Coverage is less than 90%. I won't update the badge.")
sys.exit(1)
percentage = output.decode("utf-8").strip().replace(".", ",")
coverage_line = (
@ -27,4 +27,4 @@ with open(readme_file, "w") as f:
f.write(coverage_line + "\n")
else:
f.write(line)
print(f"Coverage updated with value: {percentage}")
print(f"Coverage updated with value: {percentage}")