Continue Test Network

This commit is contained in:
2023-10-09 11:25:30 +02:00
parent e3ae073333
commit 8fdad78a8c
6 changed files with 324 additions and 82 deletions

View File

@@ -5,14 +5,29 @@
#include "TestUtils.h"
#include "Network.h"
void buildModel(bayesnet::Network& net, const vector<string>& features, const string& className)
{
vector<pair<int, int>> network = { {0, 1}, {0, 2}, {1, 3} };
for (const auto& feature : features) {
net.addNode(feature);
}
net.addNode(className);
for (const auto& edge : network) {
net.addEdge(features.at(edge.first), features.at(edge.second));
}
for (const auto& feature : features) {
net.addEdge(className, feature);
}
}
TEST_CASE("Test Bayesian Network", "[BayesNet]")
{
auto raw = RawDatasets("iris", true);
auto net = bayesnet::Network();
SECTION("Test get features")
{
auto net = bayesnet::Network();
net.addNode("A");
net.addNode("B");
REQUIRE(net.getFeatures() == vector<string>{"A", "B"});
@@ -21,7 +36,6 @@ TEST_CASE("Test Bayesian Network", "[BayesNet]")
}
SECTION("Test get edges")
{
auto net = bayesnet::Network();
net.addNode("A");
net.addNode("B");
net.addNode("C");
@@ -35,7 +49,6 @@ TEST_CASE("Test Bayesian Network", "[BayesNet]")
}
SECTION("Test getNodes")
{
auto net = bayesnet::Network();
net.addNode("A");
net.addNode("B");
auto& nodes = net.getNodes();
@@ -43,13 +56,119 @@ TEST_CASE("Test Bayesian Network", "[BayesNet]")
REQUIRE(nodes.count("B") == 1);
}
SECTION("Test fit")
SECTION("Test fit Network")
{
auto net2 = bayesnet::Network();
auto net3 = bayesnet::Network();
net3.initialize();
net2.initialize();
net.initialize();
buildModel(net, raw.featuresv, raw.classNamev);
buildModel(net2, raw.featurest, raw.classNamet);
buildModel(net3, raw.featurest, raw.classNamet);
vector<pair<string, string>> edges = {
{"class", "sepallength"}, {"class", "sepalwidth"}, {"class", "petallength"},
{"class", "petalwidth" }, {"sepallength", "sepalwidth"}, {"sepallength", "petallength"},
{"sepalwidth", "petalwidth"}
};
REQUIRE(net.getEdges() == edges);
REQUIRE(net2.getEdges() == edges);
REQUIRE(net3.getEdges() == edges);
vector<string> features = { "sepallength", "sepalwidth", "petallength", "petalwidth", "class" };
REQUIRE(net.getFeatures() == features);
REQUIRE(net2.getFeatures() == features);
REQUIRE(net3.getFeatures() == features);
auto& nodes = net.getNodes();
auto& nodes2 = net2.getNodes();
auto& nodes3 = net3.getNodes();
// Check Nodes parents & children
for (const auto& feature : features) {
// Parents
vector<string> parents, parents2, parents3, children, children2, children3;
auto nodeParents = nodes[feature]->getParents();
auto nodeParents2 = nodes2[feature]->getParents();
auto nodeParents3 = nodes3[feature]->getParents();
transform(nodeParents.begin(), nodeParents.end(), back_inserter(parents), [](const auto& p) { return p->getName(); });
transform(nodeParents2.begin(), nodeParents2.end(), back_inserter(parents2), [](const auto& p) { return p->getName(); });
transform(nodeParents3.begin(), nodeParents3.end(), back_inserter(parents3), [](const auto& p) { return p->getName(); });
REQUIRE(parents == parents2);
REQUIRE(parents == parents3);
// Children
auto nodeChildren = nodes[feature]->getChildren();
auto nodeChildren2 = nodes2[feature]->getChildren();
auto nodeChildren3 = nodes2[feature]->getChildren();
transform(nodeChildren.begin(), nodeChildren.end(), back_inserter(children), [](const auto& p) { return p->getName(); });
transform(nodeChildren2.begin(), nodeChildren2.end(), back_inserter(children2), [](const auto& p) { return p->getName(); });
transform(nodeChildren3.begin(), nodeChildren3.end(), back_inserter(children3), [](const auto& p) { return p->getName(); });
REQUIRE(children == children2);
REQUIRE(children == children3);
}
// Fit networks
net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv);
net2.fit(raw.dataset, raw.weights, raw.featurest, raw.classNamet, raw.statest);
net3.fit(raw.Xt, raw.yt, raw.weights, raw.featurest, raw.classNamet, raw.statest);
REQUIRE(net.getStates() == net2.getStates());
REQUIRE(net.getStates() == net3.getStates());
// Check Conditional Probabilities tables
for (int i = 0; i < features.size(); ++i) {
auto feature = features.at(i);
for (const auto& feature : features) {
auto cpt = nodes[feature]->getCPT();
auto cpt2 = nodes2[feature]->getCPT();
auto cpt3 = nodes3[feature]->getCPT();
REQUIRE(cpt.equal(cpt2));
REQUIRE(cpt.equal(cpt3));
}
}
}
SECTION("Test show")
{
auto net = bayesnet::Network();
// net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv);
net.fit(raw.Xt, raw.yt, raw.weights, raw.featurest, raw.classNamet, raw.statest);
REQUIRE(net.getClassName() == "class");
net.addNode("A");
net.addNode("B");
net.addNode("C");
net.addEdge("A", "B");
net.addEdge("A", "C");
auto str = net.show();
REQUIRE(str.size() == 3);
REQUIRE(str[0] == "A -> B, C, ");
REQUIRE(str[1] == "B -> ");
REQUIRE(str[2] == "C -> ");
}
SECTION("Test topological_sort")
{
auto net = bayesnet::Network();
net.addNode("A");
net.addNode("B");
net.addNode("C");
net.addEdge("A", "B");
net.addEdge("A", "C");
auto sorted = net.topological_sort();
REQUIRE(sorted.size() == 3);
REQUIRE(sorted[0] == "A");
bool result = sorted[1] == "B" && sorted[2] == "C";
REQUIRE(result);
}
SECTION("Test graph")
{
auto net = bayesnet::Network();
net.addNode("A");
net.addNode("B");
net.addNode("C");
net.addEdge("A", "B");
net.addEdge("A", "C");
auto str = net.graph("Test Graph");
REQUIRE(str.size() == 7);
cout << str << endl;
REQUIRE(str[0] == "digraph BayesNet {\nlabel=<BayesNet Test Graph>\nfontsize=30\nfontcolor=blue\nlabelloc=t\nlayout=circo\n");
REQUIRE(str[1] == "A [shape=circle] \n");
REQUIRE(str[2] == "A -> B");
REQUIRE(str[3] == "A -> C");
REQUIRE(str[4] == "B [shape=circle] \n");
REQUIRE(str[5] == "C [shape=circle] \n");
REQUIRE(str[6] == "}\n");
}
// SECTION("Test predict")
// {
@@ -81,34 +200,8 @@ TEST_CASE("Test Bayesian Network", "[BayesNet]")
// REQUIRE(score == Catch::Approx();
// }
// SECTION("Test topological_sort")
// {
// auto net = bayesnet::Network();
// net.addNode("A");
// net.addNode("B");
// net.addNode("C");
// net.addEdge("A", "B");
// net.addEdge("A", "C");
// auto sorted = net.topological_sort();
// REQUIRE(sorted.size() == 3);
// REQUIRE(sorted[0] == "A");
// REQUIRE((sorted[1] == "B" && sorted[2] == "C") || (sorted[1] == "C" && sorted[2] == "B"));
// }
// SECTION("Test show")
// {
// auto net = bayesnet::Network();
// net.addNode("A");
// net.addNode("B");
// net.addNode("C");
// net.addEdge("A", "B");
// net.addEdge("A", "C");
// auto str = net.show();
// REQUIRE(str.size() == 3);
// REQUIRE(str[0] == "A");
// REQUIRE(str[1] == "B -> C");
// REQUIRE(str[2] == "C");
// }
//
//
// SECTION("Test graph")
// {