159 lines
6.0 KiB
C++
159 lines
6.0 KiB
C++
// ***************************************************************
|
|
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
|
|
// SPDX-FileType: SOURCE
|
|
// SPDX-License-Identifier: MIT
|
|
// ***************************************************************
|
|
|
|
#include <catch2/catch_test_macros.hpp>
|
|
#include <catch2/catch_approx.hpp>
|
|
#include <catch2/generators/catch_generators.hpp>
|
|
#include <catch2/matchers/catch_matchers.hpp>
|
|
#include <string>
|
|
#include <vector>
|
|
#include "TestUtils.h"
|
|
#include "bayesnet/network/Network.h"
|
|
|
|
|
|
|
|
TEST_CASE("Test Node children and parents", "[Node]")
|
|
{
|
|
auto node = bayesnet::Node("Node");
|
|
REQUIRE(node.getName() == "Node");
|
|
auto parent_1 = bayesnet::Node("P1");
|
|
auto parent_2 = bayesnet::Node("P2");
|
|
auto child_1 = bayesnet::Node("H1");
|
|
auto child_2 = bayesnet::Node("H2");
|
|
auto child_3 = bayesnet::Node("H3");
|
|
node.addParent(&parent_1);
|
|
node.addParent(&parent_2);
|
|
node.addChild(&child_1);
|
|
node.addChild(&child_2);
|
|
node.addChild(&child_3);
|
|
auto parents = node.getParents();
|
|
auto children = node.getChildren();
|
|
REQUIRE(parents.size() == 2);
|
|
REQUIRE(children.size() == 3);
|
|
REQUIRE(parents[0]->getName() == "P1");
|
|
REQUIRE(parents[1]->getName() == "P2");
|
|
REQUIRE(children[0]->getName() == "H1");
|
|
REQUIRE(children[1]->getName() == "H2");
|
|
REQUIRE(children[2]->getName() == "H3");
|
|
node.removeParent(&parent_1);
|
|
node.removeChild(&child_1);
|
|
parents = node.getParents();
|
|
children = node.getChildren();
|
|
REQUIRE(parents.size() == 1);
|
|
REQUIRE(children.size() == 2);
|
|
node.clear();
|
|
parents = node.getParents();
|
|
children = node.getChildren();
|
|
REQUIRE(parents.size() == 0);
|
|
REQUIRE(children.size() == 0);
|
|
}
|
|
TEST_CASE("Test Node computeCPT", "[Node]")
|
|
{
|
|
// Generate a test to test the computeCPT method of the Node class
|
|
// Create a dataset with 3 features and 4 samples
|
|
// The dataset is a 2D tensor with 4 rows and 4 columns
|
|
auto dataset = torch::tensor({ {1, 0, 0, 1}, {1, 1, 2, 0}, {0, 1, 2, 1}, {0, 1, 0, 1} });
|
|
auto states = std::vector<int>({ 2, 3, 3 });
|
|
// Create a vector with the names of the features
|
|
auto features = std::vector<std::string>{ "F1", "F2", "F3" };
|
|
// Create a vector with the names of the classes
|
|
auto className = std::string("Class");
|
|
// weights
|
|
auto weights = torch::tensor({ 1.0, 1.0, 1.0, 1.0 });
|
|
std::vector<bayesnet::Node> nodes;
|
|
for (int i = 0; i < features.size(); i++) {
|
|
auto node = bayesnet::Node(features[i]);
|
|
node.setNumStates(states[i]);
|
|
nodes.push_back(node);
|
|
}
|
|
nodes.push_back(bayesnet::Node(className));
|
|
nodes[features.size()].setNumStates(2);
|
|
for (int i = 0; i < features.size(); i++) {
|
|
// Add class node as parent of all feature nodes
|
|
nodes[i].addParent(&nodes[features.size()]);
|
|
// Node[0] -> Node[1], Node[2]
|
|
if (i > 0)
|
|
nodes[i].addParent(&nodes[0]);
|
|
}
|
|
features.push_back(className);
|
|
// Compute the conditional probability table
|
|
nodes[1].computeCPT(dataset, features, 0.0, weights);
|
|
// Get the conditional probability table
|
|
auto cpTable = nodes[1].getCPT();
|
|
// Get the dimensions of the conditional probability table
|
|
auto dimensions = cpTable.sizes();
|
|
// Check the dimensions of the conditional probability table
|
|
REQUIRE(dimensions.size() == 3);
|
|
REQUIRE(dimensions[0] == 3);
|
|
REQUIRE(dimensions[1] == 2);
|
|
REQUIRE(dimensions[2] == 2);
|
|
// Check the values of the conditional probability table
|
|
REQUIRE(cpTable[0][0][0].item<float>() == Catch::Approx(0));
|
|
REQUIRE(cpTable[0][0][1].item<float>() == Catch::Approx(0));
|
|
REQUIRE(cpTable[0][1][0].item<float>() == Catch::Approx(0));
|
|
REQUIRE(cpTable[0][1][1].item<float>() == Catch::Approx(1));
|
|
REQUIRE(cpTable[1][0][0].item<float>() == Catch::Approx(0));
|
|
REQUIRE(cpTable[1][0][1].item<float>() == Catch::Approx(1));
|
|
REQUIRE(cpTable[1][1][0].item<float>() == Catch::Approx(1));
|
|
REQUIRE(cpTable[1][1][1].item<float>() == Catch::Approx(0));
|
|
// Compute evidence
|
|
for (auto& node : nodes) {
|
|
node.computeCPT(dataset, features, 0.0, weights);
|
|
}
|
|
auto evidence = std::map<std::string, int>{ { "F1", 0 }, { "F2", 1 }, { "F3", 1 } };
|
|
REQUIRE(nodes[3].getFactorValue(evidence) == 0.5);
|
|
// Oddities
|
|
auto features_back = features;
|
|
// Remove a parent from features
|
|
features.pop_back();
|
|
REQUIRE_THROWS_AS(nodes[0].computeCPT(dataset, features, 0.0, weights), std::logic_error);
|
|
REQUIRE_THROWS_WITH(nodes[0].computeCPT(dataset, features, 0.0, weights), "Feature parent Class not found in dataset");
|
|
// Remove a feature from features
|
|
features = features_back;
|
|
features.erase(features.begin());
|
|
REQUIRE_THROWS_AS(nodes[0].computeCPT(dataset, features, 0.0, weights), std::logic_error);
|
|
REQUIRE_THROWS_WITH(nodes[0].computeCPT(dataset, features, 0.0, weights), "Feature F1 not found in dataset");
|
|
}
|
|
TEST_CASE("TEST MinFill method", "[Node]")
|
|
{
|
|
// Generate a test to test the minFill method of the Node class
|
|
// Create a graph with 5 nodes
|
|
// The graph is a chain with some additional edges
|
|
// 0 -> 1,2,3
|
|
// 1 -> 2,4
|
|
// 2 -> 3
|
|
// 3 -> 4
|
|
auto node_0 = bayesnet::Node("0");
|
|
auto node_1 = bayesnet::Node("1");
|
|
auto node_2 = bayesnet::Node("2");
|
|
auto node_3 = bayesnet::Node("3");
|
|
auto node_4 = bayesnet::Node("4");
|
|
// node 0
|
|
node_0.addChild(&node_1);
|
|
node_0.addChild(&node_2);
|
|
node_0.addChild(&node_3);
|
|
// node 1
|
|
node_1.addChild(&node_2);
|
|
node_1.addChild(&node_4);
|
|
node_1.addParent(&node_0);
|
|
// node 2
|
|
node_2.addChild(&node_3);
|
|
node_2.addChild(&node_4);
|
|
node_2.addParent(&node_0);
|
|
node_2.addParent(&node_1);
|
|
// node 3
|
|
node_3.addChild(&node_4);
|
|
node_3.addParent(&node_0);
|
|
node_3.addParent(&node_2);
|
|
// node 4
|
|
node_4.addParent(&node_1);
|
|
node_4.addParent(&node_3);
|
|
REQUIRE(node_0.minFill() == 3);
|
|
REQUIRE(node_1.minFill() == 3);
|
|
REQUIRE(node_2.minFill() == 6);
|
|
REQUIRE(node_3.minFill() == 3);
|
|
REQUIRE(node_4.minFill() == 1);
|
|
} |