// *************************************************************** // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez // SPDX-FileType: SOURCE // SPDX-License-Identifier: MIT // *************************************************************** #include #include #include #include #include #include #include "TestUtils.h" #include "bayesnet/network/Network.h" TEST_CASE("Test Node children and parents", "[Node]") { auto node = bayesnet::Node("Node"); REQUIRE(node.getName() == "Node"); auto parent_1 = bayesnet::Node("P1"); auto parent_2 = bayesnet::Node("P2"); auto child_1 = bayesnet::Node("H1"); auto child_2 = bayesnet::Node("H2"); auto child_3 = bayesnet::Node("H3"); node.addParent(&parent_1); node.addParent(&parent_2); node.addChild(&child_1); node.addChild(&child_2); node.addChild(&child_3); auto parents = node.getParents(); auto children = node.getChildren(); REQUIRE(parents.size() == 2); REQUIRE(children.size() == 3); REQUIRE(parents[0]->getName() == "P1"); REQUIRE(parents[1]->getName() == "P2"); REQUIRE(children[0]->getName() == "H1"); REQUIRE(children[1]->getName() == "H2"); REQUIRE(children[2]->getName() == "H3"); node.removeParent(&parent_1); node.removeChild(&child_1); parents = node.getParents(); children = node.getChildren(); REQUIRE(parents.size() == 1); REQUIRE(children.size() == 2); node.clear(); parents = node.getParents(); children = node.getChildren(); REQUIRE(parents.size() == 0); REQUIRE(children.size() == 0); } TEST_CASE("Test Node computeCPT", "[Node]") { // Generate a test to test the computeCPT method of the Node class // Create a dataset with 3 features and 4 samples // The dataset is a 2D tensor with 4 rows and 4 columns auto dataset = torch::tensor({ {1, 0, 0, 1}, {1, 1, 2, 0}, {0, 1, 2, 1}, {0, 1, 0, 1} }); auto states = std::vector({ 2, 3, 3 }); // Create a vector with the names of the features auto features = std::vector{ "F1", "F2", "F3" }; // Create a vector with the names of the classes auto className = std::string("Class"); // weights auto weights = torch::tensor({ 1.0, 1.0, 1.0, 1.0 }); std::vector nodes; for (int i = 0; i < features.size(); i++) { auto node = bayesnet::Node(features[i]); node.setNumStates(states[i]); nodes.push_back(node); } nodes.push_back(bayesnet::Node(className)); nodes[features.size()].setNumStates(2); for (int i = 0; i < features.size(); i++) { // Add class node as parent of all feature nodes nodes[i].addParent(&nodes[features.size()]); // Node[0] -> Node[1], Node[2] if (i > 0) nodes[i].addParent(&nodes[0]); } features.push_back(className); // Compute the conditional probability table nodes[1].computeCPT(dataset, features, 0.0, weights); // Get the conditional probability table auto cpTable = nodes[1].getCPT(); // Get the dimensions of the conditional probability table auto dimensions = cpTable.sizes(); // Check the dimensions of the conditional probability table REQUIRE(dimensions.size() == 3); REQUIRE(dimensions[0] == 3); REQUIRE(dimensions[1] == 2); REQUIRE(dimensions[2] == 2); // Check the values of the conditional probability table REQUIRE(cpTable[0][0][0].item() == Catch::Approx(0)); REQUIRE(cpTable[0][0][1].item() == Catch::Approx(0)); REQUIRE(cpTable[0][1][0].item() == Catch::Approx(0)); REQUIRE(cpTable[0][1][1].item() == Catch::Approx(1)); REQUIRE(cpTable[1][0][0].item() == Catch::Approx(0)); REQUIRE(cpTable[1][0][1].item() == Catch::Approx(1)); REQUIRE(cpTable[1][1][0].item() == Catch::Approx(1)); REQUIRE(cpTable[1][1][1].item() == Catch::Approx(0)); // Compute evidence for (auto& node : nodes) { node.computeCPT(dataset, features, 0.0, weights); } auto evidence = std::map{ { "F1", 0 }, { "F2", 1 }, { "F3", 1 } }; REQUIRE(nodes[3].getFactorValue(evidence) == 0.5); // Oddities auto features_back = features; // Remove a parent from features features.pop_back(); REQUIRE_THROWS_AS(nodes[0].computeCPT(dataset, features, 0.0, weights), std::logic_error); REQUIRE_THROWS_WITH(nodes[0].computeCPT(dataset, features, 0.0, weights), "Feature parent Class not found in dataset"); // Remove a feature from features features = features_back; features.erase(features.begin()); REQUIRE_THROWS_AS(nodes[0].computeCPT(dataset, features, 0.0, weights), std::logic_error); REQUIRE_THROWS_WITH(nodes[0].computeCPT(dataset, features, 0.0, weights), "Feature F1 not found in dataset"); } TEST_CASE("TEST MinFill method", "[Node]") { // Generate a test to test the minFill method of the Node class // Create a graph with 5 nodes // The graph is a chain with some additional edges // 0 -> 1,2,3 // 1 -> 2,4 // 2 -> 3 // 3 -> 4 auto node_0 = bayesnet::Node("0"); auto node_1 = bayesnet::Node("1"); auto node_2 = bayesnet::Node("2"); auto node_3 = bayesnet::Node("3"); auto node_4 = bayesnet::Node("4"); // node 0 node_0.addChild(&node_1); node_0.addChild(&node_2); node_0.addChild(&node_3); // node 1 node_1.addChild(&node_2); node_1.addChild(&node_4); node_1.addParent(&node_0); // node 2 node_2.addChild(&node_3); node_2.addChild(&node_4); node_2.addParent(&node_0); node_2.addParent(&node_1); // node 3 node_3.addChild(&node_4); node_3.addParent(&node_0); node_3.addParent(&node_2); // node 4 node_4.addParent(&node_1); node_4.addParent(&node_3); REQUIRE(node_0.minFill() == 3); REQUIRE(node_1.minFill() == 3); REQUIRE(node_2.minFill() == 6); REQUIRE(node_3.minFill() == 3); REQUIRE(node_4.minFill() == 1); }