bisection proposal #24

Merged
rmontanana merged 23 commits from bisection into main 2024-04-08 14:29:26 +00:00
2 changed files with 22 additions and 21 deletions
Showing only changes of commit e55365c41c - Show all commits

View File

@ -27,7 +27,8 @@
"name": "Linux", "name": "Linux",
"includePath": [ "includePath": [
"/home/rmontanana/Code/BayesNet/**", "/home/rmontanana/Code/BayesNet/**",
"/home/rmontanana/Code/libtorch/include/torch/csrc/api/include/" "/home/rmontanana/Code/libtorch/include/torch/csrc/api/include/",
"/home/rmontanana/Code/BayesNet/lib/"
], ],
"defines": [], "defines": [],
"cStandard": "c17", "cStandard": "c17",

View File

@ -98,26 +98,26 @@ TEST_CASE("BoostAODE feature_select CFS", "[Models]")
REQUIRE(clf.getNotes()[0] == "Used features in initialization: 6 of 9 with CFS"); REQUIRE(clf.getNotes()[0] == "Used features in initialization: 6 of 9 with CFS");
REQUIRE(clf.getNotes()[1] == "Number of models: 9"); REQUIRE(clf.getNotes()[1] == "Number of models: 9");
} }
// TEST_CASE("BoostAODE test used features in train note and score", "[BayesNet]") TEST_CASE("BoostAODE test used features in train note and score", "[Models]")
// { {
// auto raw = RawDatasets("diabetes", true); auto raw = RawDatasets("diabetes", true);
// auto clf = bayesnet::BoostAODE(true); auto clf = bayesnet::BoostAODE(true);
// clf.setHyperparameters({ clf.setHyperparameters({
// {"order", "asc"}, {"order", "asc"},
// {"convergence", true}, {"convergence", true},
// {"select_features","CFS"}, {"select_features","CFS"},
// }); });
// clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv); clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
// REQUIRE(clf.getNumberOfNodes() == 72); REQUIRE(clf.getNumberOfNodes() == 72);
// REQUIRE(clf.getNumberOfEdges() == 120); REQUIRE(clf.getNumberOfEdges() == 120);
// REQUIRE(clf.getNotes().size() == 2); REQUIRE(clf.getNotes().size() == 2);
// REQUIRE(clf.getNotes()[0] == "Used features in initialization: 7 of 8 with CFS"); REQUIRE(clf.getNotes()[0] == "Used features in initialization: 6 of 8 with CFS");
// REQUIRE(clf.getNotes()[1] == "Number of models: 8"); REQUIRE(clf.getNotes()[1] == "Number of models: 8");
// auto score = clf.score(raw.Xv, raw.yv); auto score = clf.score(raw.Xv, raw.yv);
// auto scoret = clf.score(raw.Xt, raw.yt); auto scoret = clf.score(raw.Xt, raw.yt);
// REQUIRE(score == Catch::Approx(0.82031).epsilon(raw.epsilon)); REQUIRE(score == Catch::Approx(0.82031).epsilon(raw.epsilon));
// REQUIRE(scoret == Catch::Approx(0.82031).epsilon(raw.epsilon)); REQUIRE(scoret == Catch::Approx(0.82031).epsilon(raw.epsilon));
// } }
TEST_CASE("Model predict_proba", "[Models]") TEST_CASE("Model predict_proba", "[Models]")
{ {
std::string model = GENERATE("TAN", "SPODE", "BoostAODEproba", "BoostAODEvoting"); std::string model = GENERATE("TAN", "SPODE", "BoostAODEproba", "BoostAODEvoting");