Add notes to Classifier & Changelog

This commit is contained in:
2024-02-12 10:58:20 +01:00
parent 03f8b8653b
commit f3b8150e2c
6 changed files with 95 additions and 99 deletions

View File

@@ -25,6 +25,7 @@ TEST_CASE("Test Bayesian Network", "[BayesNet]")
auto raw = RawDatasets("iris", true);
auto net = bayesnet::Network();
double threshold = 1e-4;
SECTION("Test get features")
{
@@ -167,97 +168,44 @@ TEST_CASE("Test Bayesian Network", "[BayesNet]")
REQUIRE(str[5] == "C [shape=circle] \n");
REQUIRE(str[6] == "}\n");
}
// SECTION("Test predict")
// {
// auto net = bayesnet::Network();
// net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv);
// std::vector<std::vector<int>> test = { {1, 2, 0, 1}, {0, 1, 2, 0}, {1, 1, 1, 1}, {0, 0, 0, 0}, {2, 2, 2, 2} };
// std::vector<int> y_test = { 0, 1, 1, 0, 2 };
// auto y_pred = net.predict(test);
// REQUIRE(y_pred == y_test);
// }
// SECTION("Test predict_proba")
// {
// auto net = bayesnet::Network();
// net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv);
// std::vector<std::vector<int>> test = { {1, 2, 0, 1}, {0, 1, 2, 0}, {1, 1, 1, 1}, {0, 0, 0, 0}, {2, 2, 2, 2} };
// auto y_test = { 0, 1, 1, 0, 2 };
// auto y_pred = net.predict(test);
// REQUIRE(y_pred == y_test);
// }
}
// SECTION("Test score")
// {
// auto net = bayesnet::Network();
// net.fit(Xd, y, weights, features, className, states);
// auto test = { {1, 2, 0, 1}, {0, 1, 2, 0}, {1, 1, 1, 1}, {0, 0, 0, 0}, {2, 2, 2, 2} };
// auto score = net.score(X, y);
// REQUIRE(score == Catch::Approx();
// }
//
//
// SECTION("Test graph")
// {
// auto net = bayesnet::Network();
// net.addNode("A");
// net.addNode("B");
// net.addNode("C");
// net.addEdge("A", "B");
// net.addEdge("A", "C");
// auto str = net.graph("Test Graph");
// REQUIRE(str.size() == 6);
// REQUIRE(str[0] == "digraph \"Test Graph\" {");
// REQUIRE(str[1] == " A -> B;");
// REQUIRE(str[2] == " A -> C;");
// REQUIRE(str[3] == " B [shape=ellipse];");
// REQUIRE(str[4] == " C [shape=ellipse];");
// REQUIRE(str[5] == "}");
// }
// SECTION("Test initialize")
// {
// auto net = bayesnet::Network();
// net.addNode("A");
// net.addNode("B");
// net.addNode("C");
// net.addEdge("A", "B");
// net.addEdge("A", "C");
// net.initialize();
// REQUIRE(net.getNodes().size() == 0);
// REQUIRE(net.getEdges().size() == 0);
// REQUIRE(net.getFeatures().size() == 0);
// REQUIRE(net.getClassNumStates() == 0);
// REQUIRE(net.getClassName().empty());
// REQUIRE(net.getStates() == 0);
// REQUIRE(net.getSamples().numel() == 0);
// }
// SECTION("Test dump_cpt")
// {
// auto net = bayesnet::Network();
// net.addNode("A");
// net.addNode("B");
// net.addNode("C");
// net.addEdge("A", "B");
// net.addEdge("A", "C");
// net.setClassName("C");
// net.setStates({ {"A", {0, 1}}, {"B", {0, 1}}, {"C", {0, 1, 2}} });
// net.fit({ {0, 0}, {0, 1}, {1, 0}, {1, 1} }, { 0, 1, 1, 2 }, {}, { "A", "B" }, "C", { {"A", {0, 1}}, {"B", {0, 1}}, {"C", {0, 1, 2}} });
// net.dump_cpt();
// // TODO: Check that the file was created and contains the expected data
// }
// SECTION("Test version")
// {
// auto net = bayesnet::Network();
// REQUIRE(net.version() == "0.2.0");
// }
// }
// }
SECTION("Test predict")
{
auto net = bayesnet::Network();
buildModel(net, raw.featuresv, raw.classNamev);
net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv);
std::vector<std::vector<int>> test = { {1, 2, 0, 1, 1}, {0, 1, 2, 0, 1}, {0, 0, 0, 0, 1}, {2, 2, 2, 2, 1} };
std::vector<int> y_test = { 2, 2, 0, 2, 1 };
auto y_pred = net.predict(test);
REQUIRE(y_pred == y_test);
}
SECTION("Test predict_proba")
{
auto net = bayesnet::Network();
buildModel(net, raw.featuresv, raw.classNamev);
net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv);
std::vector<std::vector<int>> test = { {1, 2, 0, 1, 1}, {0, 1, 2, 0, 1}, {0, 0, 0, 0, 1}, {2, 2, 2, 2, 1} };
std::vector<std::vector<double>> y_test = {
{0.450237, 0.0866621, 0.463101},
{0.244443, 0.0925922, 0.662964},
{0.913441, 0.0125857, 0.0739732},
{0.450237, 0.0866621, 0.463101},
{0.0135226, 0.971726, 0.0147519}
};
auto y_pred = net.predict_proba(test);
REQUIRE(y_pred.size() == 5);
REQUIRE(y_pred[0].size() == 3);
for (int i = 0; i < y_pred.size(); ++i) {
for (int j = 0; j < y_pred[i].size(); ++j) {
REQUIRE(y_pred[i][j] == Catch::Approx(y_test[i][j]).margin(threshold));
}
}
}
SECTION("Test score")
{
auto net = bayesnet::Network();
buildModel(net, raw.featuresv, raw.classNamev);
net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv);
auto score = net.score(raw.Xv, raw.yv);
REQUIRE(score == Catch::Approx(0.97333333).margin(threshold));
}
}