Complete implementation with tests
This commit is contained in:
@@ -345,12 +345,12 @@ TEST_CASE("Test Bayesian Network", "[Network]")
|
||||
auto net1 = bayesnet::Network();
|
||||
buildModel(net1, raw.features, raw.className);
|
||||
net1.fit(raw.Xv, raw.yv, raw.weightsv, raw.features, raw.className, raw.states, raw.smoothing);
|
||||
|
||||
|
||||
// Create empty network and assign
|
||||
auto net2 = bayesnet::Network();
|
||||
net2.addNode("TempNode"); // Add something to make sure it gets cleared
|
||||
net2 = net1;
|
||||
|
||||
|
||||
// Verify they are equal
|
||||
REQUIRE(net1.getFeatures() == net2.getFeatures());
|
||||
REQUIRE(net1.getEdges() == net2.getEdges());
|
||||
@@ -361,10 +361,10 @@ TEST_CASE("Test Bayesian Network", "[Network]")
|
||||
REQUIRE(net1.getSamples().size(0) == net2.getSamples().size(0));
|
||||
REQUIRE(net1.getSamples().size(1) == net2.getSamples().size(1));
|
||||
REQUIRE(net1.getNodes().size() == net2.getNodes().size());
|
||||
|
||||
|
||||
// Verify topology equality
|
||||
REQUIRE(net1 == net2);
|
||||
|
||||
|
||||
// Verify they are separate objects by modifying one
|
||||
net2.initialize();
|
||||
net2.addNode("OnlyInNet2");
|
||||
@@ -376,46 +376,47 @@ TEST_CASE("Test Bayesian Network", "[Network]")
|
||||
INFO("Test self assignment");
|
||||
buildModel(net, raw.features, raw.className);
|
||||
net.fit(raw.Xv, raw.yv, raw.weightsv, raw.features, raw.className, raw.states, raw.smoothing);
|
||||
|
||||
|
||||
int original_edges = net.getNumEdges();
|
||||
int original_nodes = net.getNodes().size();
|
||||
|
||||
|
||||
// Self assignment should not corrupt the network
|
||||
net = net;
|
||||
|
||||
auto all_features = raw.features;
|
||||
all_features.push_back(raw.className);
|
||||
REQUIRE(net.getNumEdges() == original_edges);
|
||||
REQUIRE(net.getNodes().size() == original_nodes);
|
||||
REQUIRE(net.getFeatures() == raw.features);
|
||||
REQUIRE(net.getFeatures() == all_features);
|
||||
REQUIRE(net.getClassName() == raw.className);
|
||||
}
|
||||
SECTION("Test operator== topology comparison")
|
||||
{
|
||||
INFO("Test operator== topology comparison");
|
||||
|
||||
|
||||
// Test 1: Two identical networks
|
||||
auto net1 = bayesnet::Network();
|
||||
auto net2 = bayesnet::Network();
|
||||
|
||||
|
||||
net1.addNode("A");
|
||||
net1.addNode("B");
|
||||
net1.addNode("C");
|
||||
net1.addEdge("A", "B");
|
||||
net1.addEdge("B", "C");
|
||||
|
||||
|
||||
net2.addNode("A");
|
||||
net2.addNode("B");
|
||||
net2.addNode("C");
|
||||
net2.addEdge("A", "B");
|
||||
net2.addEdge("B", "C");
|
||||
|
||||
|
||||
REQUIRE(net1 == net2);
|
||||
|
||||
|
||||
// Test 2: Different nodes
|
||||
auto net3 = bayesnet::Network();
|
||||
net3.addNode("A");
|
||||
net3.addNode("D"); // Different node
|
||||
REQUIRE_FALSE(net1 == net3);
|
||||
|
||||
|
||||
// Test 3: Same nodes, different edges
|
||||
auto net4 = bayesnet::Network();
|
||||
net4.addNode("A");
|
||||
@@ -424,12 +425,12 @@ TEST_CASE("Test Bayesian Network", "[Network]")
|
||||
net4.addEdge("A", "C"); // Different topology
|
||||
net4.addEdge("B", "C");
|
||||
REQUIRE_FALSE(net1 == net4);
|
||||
|
||||
|
||||
// Test 4: Empty networks
|
||||
auto net5 = bayesnet::Network();
|
||||
auto net6 = bayesnet::Network();
|
||||
REQUIRE(net5 == net6);
|
||||
|
||||
|
||||
// Test 5: Same topology, different edge order
|
||||
auto net7 = bayesnet::Network();
|
||||
net7.addNode("A");
|
||||
@@ -442,35 +443,36 @@ TEST_CASE("Test Bayesian Network", "[Network]")
|
||||
SECTION("Test RAII compliance with smart pointers")
|
||||
{
|
||||
INFO("Test RAII compliance with smart pointers");
|
||||
|
||||
|
||||
std::unique_ptr<bayesnet::Network> net1 = std::make_unique<bayesnet::Network>();
|
||||
buildModel(*net1, raw.features, raw.className);
|
||||
net1->fit(raw.Xv, raw.yv, raw.weightsv, raw.features, raw.className, raw.states, raw.smoothing);
|
||||
|
||||
|
||||
// Test that copy constructor works with smart pointers
|
||||
std::unique_ptr<bayesnet::Network> net2 = std::make_unique<bayesnet::Network>(*net1);
|
||||
|
||||
|
||||
REQUIRE(*net1 == *net2);
|
||||
REQUIRE(net1->getNumEdges() == net2->getNumEdges());
|
||||
REQUIRE(net1->getNodes().size() == net2->getNodes().size());
|
||||
|
||||
|
||||
// Destroy original
|
||||
net1.reset();
|
||||
|
||||
|
||||
// Test predictions still work
|
||||
std::vector<std::vector<int>> test = { {1}, {2}, {0}, {1} };
|
||||
REQUIRE_NOTHROW(net2->predict(test));
|
||||
|
||||
// net2 should still be valid and functional
|
||||
net2->initialize();
|
||||
REQUIRE_NOTHROW(net2->addNode("NewNode"));
|
||||
REQUIRE(net2->getNodes().count("NewNode") == 1);
|
||||
|
||||
// Test predictions still work
|
||||
std::vector<std::vector<int>> test = { {1, 2, 0, 1, 1} };
|
||||
REQUIRE_NOTHROW(net2->predict(test));
|
||||
}
|
||||
SECTION("Test complex topology copy")
|
||||
{
|
||||
INFO("Test complex topology copy");
|
||||
|
||||
|
||||
auto original = bayesnet::Network();
|
||||
|
||||
|
||||
// Create a more complex network
|
||||
original.addNode("Root");
|
||||
original.addNode("Child1");
|
||||
@@ -478,45 +480,45 @@ TEST_CASE("Test Bayesian Network", "[Network]")
|
||||
original.addNode("Grandchild1");
|
||||
original.addNode("Grandchild2");
|
||||
original.addNode("Grandchild3");
|
||||
|
||||
|
||||
original.addEdge("Root", "Child1");
|
||||
original.addEdge("Root", "Child2");
|
||||
original.addEdge("Child1", "Grandchild1");
|
||||
original.addEdge("Child1", "Grandchild2");
|
||||
original.addEdge("Child2", "Grandchild3");
|
||||
|
||||
|
||||
// Copy it
|
||||
auto copy = original;
|
||||
|
||||
|
||||
// Verify topology is identical
|
||||
REQUIRE(original == copy);
|
||||
REQUIRE(original.getNodes().size() == copy.getNodes().size());
|
||||
REQUIRE(original.getNumEdges() == copy.getNumEdges());
|
||||
|
||||
|
||||
// Verify edges are properly reconstructed
|
||||
auto originalEdges = original.getEdges();
|
||||
auto copyEdges = copy.getEdges();
|
||||
REQUIRE(originalEdges.size() == copyEdges.size());
|
||||
|
||||
|
||||
// Verify node relationships are properly copied
|
||||
for (const auto& nodePair : original.getNodes()) {
|
||||
const std::string& nodeName = nodePair.first;
|
||||
auto* originalNode = nodePair.second.get();
|
||||
auto* copyNode = copy.getNodes().at(nodeName).get();
|
||||
|
||||
|
||||
REQUIRE(originalNode->getParents().size() == copyNode->getParents().size());
|
||||
REQUIRE(originalNode->getChildren().size() == copyNode->getChildren().size());
|
||||
|
||||
|
||||
// Verify parent names match
|
||||
for (size_t i = 0; i < originalNode->getParents().size(); ++i) {
|
||||
REQUIRE(originalNode->getParents()[i]->getName() ==
|
||||
copyNode->getParents()[i]->getName());
|
||||
REQUIRE(originalNode->getParents()[i]->getName() ==
|
||||
copyNode->getParents()[i]->getName());
|
||||
}
|
||||
|
||||
|
||||
// Verify child names match
|
||||
for (size_t i = 0; i < originalNode->getChildren().size(); ++i) {
|
||||
REQUIRE(originalNode->getChildren()[i]->getName() ==
|
||||
copyNode->getChildren()[i]->getName());
|
||||
REQUIRE(originalNode->getChildren()[i]->getName() ==
|
||||
copyNode->getChildren()[i]->getName());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user