Refactor coverage report generation
Add some tests to reach 99%
This commit is contained in:
@@ -7,7 +7,9 @@
|
||||
#include <catch2/catch_test_macros.hpp>
|
||||
#include <catch2/catch_approx.hpp>
|
||||
#include <catch2/generators/catch_generators.hpp>
|
||||
#include <catch2/matchers/catch_matchers.hpp>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "TestUtils.h"
|
||||
#include "bayesnet/network/Network.h"
|
||||
|
||||
@@ -48,6 +50,73 @@ TEST_CASE("Test Node children and parents", "[Node]")
|
||||
REQUIRE(parents.size() == 0);
|
||||
REQUIRE(children.size() == 0);
|
||||
}
|
||||
TEST_CASE("Test Node computeCPT", "[Node]")
|
||||
{
|
||||
// Generate a test to test the computeCPT method of the Node class
|
||||
// Create a dataset with 3 features and 4 samples
|
||||
// The dataset is a 2D tensor with 4 rows and 4 columns
|
||||
auto dataset = torch::tensor({ {1, 0, 0, 1}, {1, 1, 2, 0}, {0, 1, 2, 1}, {0, 1, 0, 1} });
|
||||
auto states = std::vector<int>({ 2, 3, 3 });
|
||||
// Create a vector with the names of the features
|
||||
auto features = std::vector<std::string>{ "F1", "F2", "F3" };
|
||||
// Create a vector with the names of the classes
|
||||
auto className = std::string("Class");
|
||||
// weights
|
||||
auto weights = torch::tensor({ 1.0, 1.0, 1.0, 1.0 });
|
||||
std::vector<bayesnet::Node> nodes;
|
||||
for (int i = 0; i < features.size(); i++) {
|
||||
auto node = bayesnet::Node(features[i]);
|
||||
node.setNumStates(states[i]);
|
||||
nodes.push_back(node);
|
||||
}
|
||||
nodes.push_back(bayesnet::Node(className));
|
||||
nodes[features.size()].setNumStates(2);
|
||||
for (int i = 0; i < features.size(); i++) {
|
||||
// Add class node as parent of all feature nodes
|
||||
nodes[i].addParent(&nodes[features.size()]);
|
||||
// Node[0] -> Node[1], Node[2]
|
||||
if (i > 0)
|
||||
nodes[i].addParent(&nodes[0]);
|
||||
}
|
||||
features.push_back(className);
|
||||
// Compute the conditional probability table
|
||||
nodes[1].computeCPT(dataset, features, 0.0, weights);
|
||||
// Get the conditional probability table
|
||||
auto cpTable = nodes[1].getCPT();
|
||||
// Get the dimensions of the conditional probability table
|
||||
auto dimensions = cpTable.sizes();
|
||||
// Check the dimensions of the conditional probability table
|
||||
REQUIRE(dimensions.size() == 3);
|
||||
REQUIRE(dimensions[0] == 3);
|
||||
REQUIRE(dimensions[1] == 2);
|
||||
REQUIRE(dimensions[2] == 2);
|
||||
// Check the values of the conditional probability table
|
||||
REQUIRE(cpTable[0][0][0].item<float>() == Catch::Approx(0));
|
||||
REQUIRE(cpTable[0][0][1].item<float>() == Catch::Approx(0));
|
||||
REQUIRE(cpTable[0][1][0].item<float>() == Catch::Approx(0));
|
||||
REQUIRE(cpTable[0][1][1].item<float>() == Catch::Approx(1));
|
||||
REQUIRE(cpTable[1][0][0].item<float>() == Catch::Approx(0));
|
||||
REQUIRE(cpTable[1][0][1].item<float>() == Catch::Approx(1));
|
||||
REQUIRE(cpTable[1][1][0].item<float>() == Catch::Approx(1));
|
||||
REQUIRE(cpTable[1][1][1].item<float>() == Catch::Approx(0));
|
||||
// Compute evidence
|
||||
for (auto& node : nodes) {
|
||||
node.computeCPT(dataset, features, 0.0, weights);
|
||||
}
|
||||
auto evidence = std::map<std::string, int>{ { "F1", 0 }, { "F2", 1 }, { "F3", 1 } };
|
||||
REQUIRE(nodes[3].getFactorValue(evidence) == 0.5);
|
||||
// Oddities
|
||||
auto features_back = features;
|
||||
// Remove a parent from features
|
||||
features.pop_back();
|
||||
REQUIRE_THROWS_AS(nodes[0].computeCPT(dataset, features, 0.0, weights), std::logic_error);
|
||||
REQUIRE_THROWS_WITH(nodes[0].computeCPT(dataset, features, 0.0, weights), "Feature parent Class not found in dataset");
|
||||
// Remove a feature from features
|
||||
features = features_back;
|
||||
features.erase(features.begin());
|
||||
REQUIRE_THROWS_AS(nodes[0].computeCPT(dataset, features, 0.0, weights), std::logic_error);
|
||||
REQUIRE_THROWS_WITH(nodes[0].computeCPT(dataset, features, 0.0, weights), "Feature F1 not found in dataset");
|
||||
}
|
||||
TEST_CASE("TEST MinFill method", "[Node]")
|
||||
{
|
||||
// Generate a test to test the minFill method of the Node class
|
||||
|
Reference in New Issue
Block a user