Fix tests
This commit is contained in:
@@ -149,6 +149,7 @@ TEST_CASE("Test Bayesian Network", "[Network]")
|
||||
}
|
||||
SECTION("Test show")
|
||||
{
|
||||
INFO("Test show");
|
||||
net.addNode("A");
|
||||
net.addNode("B");
|
||||
net.addNode("C");
|
||||
@@ -162,6 +163,7 @@ TEST_CASE("Test Bayesian Network", "[Network]")
|
||||
}
|
||||
SECTION("Test topological_sort")
|
||||
{
|
||||
INFO("Test topological sort");
|
||||
net.addNode("A");
|
||||
net.addNode("B");
|
||||
net.addNode("C");
|
||||
@@ -175,6 +177,7 @@ TEST_CASE("Test Bayesian Network", "[Network]")
|
||||
}
|
||||
SECTION("Test graph")
|
||||
{
|
||||
INFO("Test graph");
|
||||
net.addNode("A");
|
||||
net.addNode("B");
|
||||
net.addNode("C");
|
||||
@@ -192,6 +195,7 @@ TEST_CASE("Test Bayesian Network", "[Network]")
|
||||
}
|
||||
SECTION("Test predict")
|
||||
{
|
||||
INFO("Test predict");
|
||||
buildModel(net, raw.features, raw.className);
|
||||
net.fit(raw.Xv, raw.yv, raw.weightsv, raw.features, raw.className, raw.states, raw.smoothing);
|
||||
std::vector<std::vector<int>> test = { {1, 2, 0, 1, 1}, {0, 1, 2, 0, 1}, {0, 0, 0, 0, 1}, {2, 2, 2, 2, 1} };
|
||||
@@ -201,6 +205,7 @@ TEST_CASE("Test Bayesian Network", "[Network]")
|
||||
}
|
||||
SECTION("Test predict_proba")
|
||||
{
|
||||
INFO("Test predict_proba");
|
||||
buildModel(net, raw.features, raw.className);
|
||||
net.fit(raw.Xv, raw.yv, raw.weightsv, raw.features, raw.className, raw.states, raw.smoothing);
|
||||
std::vector<std::vector<int>> test = { {1, 2, 0, 1, 1}, {0, 1, 2, 0, 1}, {0, 0, 0, 0, 1}, {2, 2, 2, 2, 1} };
|
||||
@@ -222,6 +227,7 @@ TEST_CASE("Test Bayesian Network", "[Network]")
|
||||
}
|
||||
SECTION("Test score")
|
||||
{
|
||||
INFO("Test score");
|
||||
buildModel(net, raw.features, raw.className);
|
||||
net.fit(raw.Xv, raw.yv, raw.weightsv, raw.features, raw.className, raw.states, raw.smoothing);
|
||||
auto score = net.score(raw.Xv, raw.yv);
|
||||
@@ -229,6 +235,7 @@ TEST_CASE("Test Bayesian Network", "[Network]")
|
||||
}
|
||||
SECTION("Copy constructor")
|
||||
{
|
||||
INFO("Test copy constructor");
|
||||
buildModel(net, raw.features, raw.className);
|
||||
net.fit(raw.Xv, raw.yv, raw.weightsv, raw.features, raw.className, raw.states, raw.smoothing);
|
||||
auto net2 = bayesnet::Network(net);
|
||||
@@ -252,6 +259,7 @@ TEST_CASE("Test Bayesian Network", "[Network]")
|
||||
}
|
||||
SECTION("Test oddities")
|
||||
{
|
||||
INFO("Test oddities");
|
||||
buildModel(net, raw.features, raw.className);
|
||||
// predict without fitting
|
||||
std::vector<std::vector<int>> test = { {1, 2, 0, 1, 1}, {0, 1, 2, 0, 1}, {0, 0, 0, 0, 1}, {2, 2, 2, 2, 1} };
|
||||
@@ -270,10 +278,10 @@ TEST_CASE("Test Bayesian Network", "[Network]")
|
||||
netx.fit(raw.Xv, raw.yv, raw.weightsv, raw.features, raw.className, raw.states, raw.smoothing);
|
||||
std::vector<std::vector<int>> test2 = { {1, 2, 0, 1, 1}, {0, 1, 2, 0, 1}, {0, 0, 0, 0, 1} };
|
||||
auto test_tensor2 = bayesnet::vectorToTensor(test2, false);
|
||||
REQUIRE_THROWS_AS(netx.predict(test2), std::logic_error);
|
||||
REQUIRE_THROWS_WITH(netx.predict(test2), "Sample size (3) does not match the number of features (4)");
|
||||
REQUIRE_THROWS_AS(netx.predict(test_tensor2), std::logic_error);
|
||||
REQUIRE_THROWS_WITH(netx.predict(test_tensor2), "Sample size (3) does not match the number of features (4)");
|
||||
REQUIRE_THROWS_AS(netx.predict(test2), std::invalid_argument);
|
||||
REQUIRE_THROWS_WITH(netx.predict(test2), "(V) Sample size (3) does not match the number of features (4)");
|
||||
REQUIRE_THROWS_AS(netx.predict(test_tensor2), std::invalid_argument);
|
||||
REQUIRE_THROWS_WITH(netx.predict(test_tensor2), "(T) Sample size (3) does not match the number of features (4)");
|
||||
// fit with wrong data
|
||||
// Weights
|
||||
auto net2 = bayesnet::Network();
|
||||
@@ -341,15 +349,6 @@ TEST_CASE("Cicle in Network", "[Network]")
|
||||
REQUIRE_THROWS_AS(net.addEdge("C", "A"), std::invalid_argument);
|
||||
REQUIRE_THROWS_WITH(net.addEdge("C", "A"), "Adding this edge forms a cycle in the graph.");
|
||||
}
|
||||
TEST_CASE("Test max threads constructor", "[Network]")
|
||||
{
|
||||
auto net = bayesnet::Network();
|
||||
REQUIRE(net.getMaxThreads() == 0.95f);
|
||||
auto net2 = bayesnet::Network(4);
|
||||
REQUIRE(net2.getMaxThreads() == 4);
|
||||
auto net3 = bayesnet::Network(1.75);
|
||||
REQUIRE(net3.getMaxThreads() == 1.75);
|
||||
}
|
||||
TEST_CASE("Edges troubles", "[Network]")
|
||||
{
|
||||
auto net = bayesnet::Network();
|
||||
|
Reference in New Issue
Block a user