Complete implementation with tests
This commit is contained in:
@@ -158,4 +158,47 @@ TEST_CASE("TEST MinFill method", "[Node]")
|
||||
REQUIRE(node_2.minFill() == 6);
|
||||
REQUIRE(node_3.minFill() == 3);
|
||||
REQUIRE(node_4.minFill() == 1);
|
||||
}
|
||||
TEST_CASE("Test operator =", "[Node]")
|
||||
{
|
||||
// Generate a test to test the operator = of the Node class
|
||||
// Create a node with 3 parents and 2 children
|
||||
auto node = bayesnet::Node("N1");
|
||||
auto parent_1 = bayesnet::Node("P1");
|
||||
parent_1.setNumStates(3);
|
||||
auto child_1 = bayesnet::Node("H1");
|
||||
child_1.setNumStates(2);
|
||||
node.addParent(&parent_1);
|
||||
node.addChild(&child_1);
|
||||
// Create a cpt in the node using computeCPT
|
||||
auto dataset = torch::tensor({ {1, 0, 0, 1}, {0, 1, 2, 1}, {0, 1, 1, 0} });
|
||||
auto states = std::vector<int>({ 2, 3, 3 });
|
||||
auto features = std::vector<std::string>{ "N1", "P1", "H1" };
|
||||
auto className = std::string("Class");
|
||||
auto weights = torch::tensor({ 1.0, 1.0, 1.0, 1.0 }, torch::kDouble);
|
||||
node.setNumStates(2);
|
||||
node.computeCPT(dataset, features, 0.0, weights);
|
||||
// Get the cpt of the node
|
||||
auto cpt = node.getCPT();
|
||||
// Check that the cpt is not empty
|
||||
REQUIRE(cpt.numel() > 0);
|
||||
// Check that the cpt has the correct dimensions
|
||||
auto dimensions = cpt.sizes();
|
||||
REQUIRE(dimensions.size() == 2);
|
||||
REQUIRE(dimensions[0] == 2); // Number of states of the node
|
||||
REQUIRE(dimensions[1] == 3); // Number of states of the first parent
|
||||
// Create a copy of the node
|
||||
auto node_copy = node;
|
||||
// Check that the copy has not any parents or children
|
||||
auto parents = node_copy.getParents();
|
||||
auto children = node_copy.getChildren();
|
||||
REQUIRE(parents.size() == 0);
|
||||
REQUIRE(children.size() == 0);
|
||||
// Check that the copy has the same name
|
||||
REQUIRE(node_copy.getName() == "N1");
|
||||
// Check that the copy has the same cpt
|
||||
auto cpt_copy = node_copy.getCPT();
|
||||
REQUIRE(cpt_copy.equal(cpt));
|
||||
// Check that the copy has the same number of states
|
||||
REQUIRE(node_copy.getNumStates() == node.getNumStates());
|
||||
}
|
Reference in New Issue
Block a user