Refactor coverage report generation

Add some tests to reach 99%
This commit is contained in:
2024-05-06 17:56:00 +02:00
parent 0ec53f405f
commit ced29a2c2e
1091 changed files with 9366 additions and 373937 deletions

View File

@@ -7,7 +7,9 @@
#include <catch2/catch_test_macros.hpp>
#include <catch2/catch_approx.hpp>
#include <catch2/generators/catch_generators.hpp>
#include <catch2/matchers/catch_matchers.hpp>
#include <string>
#include <vector>
#include "TestUtils.h"
#include "bayesnet/network/Network.h"
@@ -48,6 +50,73 @@ TEST_CASE("Test Node children and parents", "[Node]")
REQUIRE(parents.size() == 0);
REQUIRE(children.size() == 0);
}
TEST_CASE("Test Node computeCPT", "[Node]")
{
// Generate a test to test the computeCPT method of the Node class
// Create a dataset with 3 features and 4 samples
// The dataset is a 2D tensor with 4 rows and 4 columns
auto dataset = torch::tensor({ {1, 0, 0, 1}, {1, 1, 2, 0}, {0, 1, 2, 1}, {0, 1, 0, 1} });
auto states = std::vector<int>({ 2, 3, 3 });
// Create a vector with the names of the features
auto features = std::vector<std::string>{ "F1", "F2", "F3" };
// Create a vector with the names of the classes
auto className = std::string("Class");
// weights
auto weights = torch::tensor({ 1.0, 1.0, 1.0, 1.0 });
std::vector<bayesnet::Node> nodes;
for (int i = 0; i < features.size(); i++) {
auto node = bayesnet::Node(features[i]);
node.setNumStates(states[i]);
nodes.push_back(node);
}
nodes.push_back(bayesnet::Node(className));
nodes[features.size()].setNumStates(2);
for (int i = 0; i < features.size(); i++) {
// Add class node as parent of all feature nodes
nodes[i].addParent(&nodes[features.size()]);
// Node[0] -> Node[1], Node[2]
if (i > 0)
nodes[i].addParent(&nodes[0]);
}
features.push_back(className);
// Compute the conditional probability table
nodes[1].computeCPT(dataset, features, 0.0, weights);
// Get the conditional probability table
auto cpTable = nodes[1].getCPT();
// Get the dimensions of the conditional probability table
auto dimensions = cpTable.sizes();
// Check the dimensions of the conditional probability table
REQUIRE(dimensions.size() == 3);
REQUIRE(dimensions[0] == 3);
REQUIRE(dimensions[1] == 2);
REQUIRE(dimensions[2] == 2);
// Check the values of the conditional probability table
REQUIRE(cpTable[0][0][0].item<float>() == Catch::Approx(0));
REQUIRE(cpTable[0][0][1].item<float>() == Catch::Approx(0));
REQUIRE(cpTable[0][1][0].item<float>() == Catch::Approx(0));
REQUIRE(cpTable[0][1][1].item<float>() == Catch::Approx(1));
REQUIRE(cpTable[1][0][0].item<float>() == Catch::Approx(0));
REQUIRE(cpTable[1][0][1].item<float>() == Catch::Approx(1));
REQUIRE(cpTable[1][1][0].item<float>() == Catch::Approx(1));
REQUIRE(cpTable[1][1][1].item<float>() == Catch::Approx(0));
// Compute evidence
for (auto& node : nodes) {
node.computeCPT(dataset, features, 0.0, weights);
}
auto evidence = std::map<std::string, int>{ { "F1", 0 }, { "F2", 1 }, { "F3", 1 } };
REQUIRE(nodes[3].getFactorValue(evidence) == 0.5);
// Oddities
auto features_back = features;
// Remove a parent from features
features.pop_back();
REQUIRE_THROWS_AS(nodes[0].computeCPT(dataset, features, 0.0, weights), std::logic_error);
REQUIRE_THROWS_WITH(nodes[0].computeCPT(dataset, features, 0.0, weights), "Feature parent Class not found in dataset");
// Remove a feature from features
features = features_back;
features.erase(features.begin());
REQUIRE_THROWS_AS(nodes[0].computeCPT(dataset, features, 0.0, weights), std::logic_error);
REQUIRE_THROWS_WITH(nodes[0].computeCPT(dataset, features, 0.0, weights), "Feature F1 not found in dataset");
}
TEST_CASE("TEST MinFill method", "[Node]")
{
// Generate a test to test the minFill method of the Node class