Complete implementation with tests

This commit is contained in:
2025-07-08 11:42:20 +02:00
parent 2c7352ac38
commit ed380b1494
13 changed files with 255 additions and 170 deletions

View File

@@ -158,4 +158,47 @@ TEST_CASE("TEST MinFill method", "[Node]")
REQUIRE(node_2.minFill() == 6);
REQUIRE(node_3.minFill() == 3);
REQUIRE(node_4.minFill() == 1);
}
TEST_CASE("Test operator =", "[Node]")
{
// Generate a test to test the operator = of the Node class
// Create a node with 3 parents and 2 children
auto node = bayesnet::Node("N1");
auto parent_1 = bayesnet::Node("P1");
parent_1.setNumStates(3);
auto child_1 = bayesnet::Node("H1");
child_1.setNumStates(2);
node.addParent(&parent_1);
node.addChild(&child_1);
// Create a cpt in the node using computeCPT
auto dataset = torch::tensor({ {1, 0, 0, 1}, {0, 1, 2, 1}, {0, 1, 1, 0} });
auto states = std::vector<int>({ 2, 3, 3 });
auto features = std::vector<std::string>{ "N1", "P1", "H1" };
auto className = std::string("Class");
auto weights = torch::tensor({ 1.0, 1.0, 1.0, 1.0 }, torch::kDouble);
node.setNumStates(2);
node.computeCPT(dataset, features, 0.0, weights);
// Get the cpt of the node
auto cpt = node.getCPT();
// Check that the cpt is not empty
REQUIRE(cpt.numel() > 0);
// Check that the cpt has the correct dimensions
auto dimensions = cpt.sizes();
REQUIRE(dimensions.size() == 2);
REQUIRE(dimensions[0] == 2); // Number of states of the node
REQUIRE(dimensions[1] == 3); // Number of states of the first parent
// Create a copy of the node
auto node_copy = node;
// Check that the copy has not any parents or children
auto parents = node_copy.getParents();
auto children = node_copy.getChildren();
REQUIRE(parents.size() == 0);
REQUIRE(children.size() == 0);
// Check that the copy has the same name
REQUIRE(node_copy.getName() == "N1");
// Check that the copy has the same cpt
auto cpt_copy = node_copy.getCPT();
REQUIRE(cpt_copy.equal(cpt));
// Check that the copy has the same number of states
REQUIRE(node_copy.getNumStates() == node.getNumStates());
}