Add some more tests to 97% coverage
This commit is contained in:
parent
8eeaa1beee
commit
503ad687dc
@ -5,7 +5,7 @@
|
||||
![Gitea Release](https://img.shields.io/gitea/v/release/rmontanana/bayesnet?gitea_url=https://gitea.rmontanana.es:3000)
|
||||
[![Codacy Badge](https://app.codacy.com/project/badge/Grade/cf3e0ac71d764650b1bf4d8d00d303b1)](https://app.codacy.com/gh/Doctorado-ML/BayesNet/dashboard?utm_source=gh&utm_medium=referral&utm_content=&utm_campaign=Badge_grade)
|
||||
![Gitea Last Commit](https://img.shields.io/gitea/last-commit/rmontanana/bayesnet?gitea_url=https://gitea.rmontanana.es:3000&logo=gitea)
|
||||
![Static Badge](https://img.shields.io/badge/Coverage-95,8%25-green)
|
||||
![Static Badge](https://img.shields.io/badge/Coverage-97,1%25-green)
|
||||
|
||||
Bayesian Network Classifiers using libtorch from scratch
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <list>
|
||||
#include "Mst.h"
|
||||
@ -45,15 +46,6 @@ namespace bayesnet {
|
||||
}
|
||||
}
|
||||
}
|
||||
void Graph::display_mst()
|
||||
{
|
||||
std::cout << "Edge :" << " Weight" << std::endl;
|
||||
for (int i = 0; i < T.size(); i++) {
|
||||
std::cout << T[i].second.first << " - " << T[i].second.second << " : "
|
||||
<< T[i].first;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
void insertElement(std::list<int>& variables, int variable)
|
||||
{
|
||||
|
@ -5,29 +5,28 @@
|
||||
#include <torch/torch.h>
|
||||
namespace bayesnet {
|
||||
class MST {
|
||||
private:
|
||||
torch::Tensor weights;
|
||||
std::vector<std::string> features;
|
||||
int root = 0;
|
||||
public:
|
||||
MST() = default;
|
||||
MST(const std::vector<std::string>& features, const torch::Tensor& weights, const int root);
|
||||
std::vector<std::pair<int, int>> maximumSpanningTree();
|
||||
private:
|
||||
torch::Tensor weights;
|
||||
std::vector<std::string> features;
|
||||
int root = 0;
|
||||
};
|
||||
class Graph {
|
||||
private:
|
||||
int V; // number of nodes in graph
|
||||
std::vector <std::pair<float, std::pair<int, int>>> G; // std::vector for graph
|
||||
std::vector <std::pair<float, std::pair<int, int>>> T; // std::vector for mst
|
||||
std::vector<int> parent;
|
||||
public:
|
||||
explicit Graph(int V);
|
||||
void addEdge(int u, int v, float wt);
|
||||
int find_set(int i);
|
||||
void union_set(int u, int v);
|
||||
void kruskal_algorithm();
|
||||
void display_mst();
|
||||
std::vector <std::pair<float, std::pair<int, int>>> get_mst() { return T; }
|
||||
private:
|
||||
int V; // number of nodes in graph
|
||||
std::vector <std::pair<float, std::pair<int, int>>> G; // std::vector for graph
|
||||
std::vector <std::pair<float, std::pair<int, int>>> T; // std::vector for mst
|
||||
std::vector<int> parent;
|
||||
};
|
||||
}
|
||||
#endif
|
@ -54,6 +54,13 @@ TEST_CASE("Invalid feature name", "[Classifier]")
|
||||
REQUIRE_THROWS_AS(model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, statest), std::invalid_argument);
|
||||
REQUIRE_THROWS_WITH(model.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, statest), "feature [petallength] not found in states");
|
||||
}
|
||||
TEST_CASE("Invalid hyperparameter", "[Classifier]")
|
||||
{
|
||||
auto model = bayesnet::KDB(2);
|
||||
auto raw = RawDatasets("iris", true);
|
||||
REQUIRE_THROWS_AS(model.setHyperparameters({ { "alpha", "0.0" } }), std::invalid_argument);
|
||||
REQUIRE_THROWS_WITH(model.setHyperparameters({ { "alpha", "0.0" } }), "Invalid hyperparameters{\"alpha\":\"0.0\"}");
|
||||
}
|
||||
TEST_CASE("Topological order", "[Classifier]")
|
||||
{
|
||||
auto model = bayesnet::TAN();
|
||||
|
@ -3,6 +3,8 @@
|
||||
#include <catch2/catch_approx.hpp>
|
||||
#include <catch2/generators/catch_generators.hpp>
|
||||
#include "bayesnet/ensembles/BoostAODE.h"
|
||||
#include "bayesnet/ensembles/AODE.h"
|
||||
#include "bayesnet/ensembles/AODELd.h"
|
||||
#include "TestUtils.h"
|
||||
|
||||
|
||||
@ -73,6 +75,15 @@ TEST_CASE("Graph", "[Ensemble]")
|
||||
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
|
||||
auto graph = clf.graph();
|
||||
REQUIRE(graph.size() == 56);
|
||||
auto clf2 = bayesnet::AODE();
|
||||
clf2.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
|
||||
graph = clf2.graph();
|
||||
REQUIRE(graph.size() == 56);
|
||||
raw = RawDatasets("glass", false);
|
||||
auto clf3 = bayesnet::AODELd();
|
||||
clf3.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
|
||||
graph = clf3.graph();
|
||||
REQUIRE(graph.size() == 261);
|
||||
}
|
||||
TEST_CASE("Compute ArgMax", "[Ensemble]")
|
||||
{
|
||||
|
@ -14,7 +14,7 @@
|
||||
#include "bayesnet/ensembles/BoostAODE.h"
|
||||
#include "TestUtils.h"
|
||||
|
||||
const std::string ACTUAL_VERSION = "1.0.4";
|
||||
const std::string ACTUAL_VERSION = "1.0.4.1";
|
||||
|
||||
TEST_CASE("Test Bayesian Classifiers score & version", "[Models]")
|
||||
{
|
||||
@ -52,6 +52,7 @@ TEST_CASE("Test Bayesian Classifiers score & version", "[Models]")
|
||||
auto score = clf->score(raw.Xt, raw.yt);
|
||||
INFO("Classifier: " + name + " File: " + file_name);
|
||||
REQUIRE(score == Catch::Approx(scores[{file_name, name}]).epsilon(raw.epsilon));
|
||||
REQUIRE(clf->getStatus() == bayesnet::NORMAL);
|
||||
}
|
||||
}
|
||||
SECTION("Library check version")
|
||||
@ -61,7 +62,7 @@ TEST_CASE("Test Bayesian Classifiers score & version", "[Models]")
|
||||
}
|
||||
delete clf;
|
||||
}
|
||||
TEST_CASE("Models features", "[Models]")
|
||||
TEST_CASE("Models features & Graph", "[Models]")
|
||||
{
|
||||
auto graph = std::vector<std::string>({ "digraph BayesNet {\nlabel=<BayesNet Test>\nfontsize=30\nfontcolor=blue\nlabelloc=t\nlayout=circo\n",
|
||||
"class [shape=circle, fontcolor=red, fillcolor=lightblue, style=filled ] \n",
|
||||
@ -70,15 +71,30 @@ TEST_CASE("Models features", "[Models]")
|
||||
"sepallength -> sepalwidth", "sepalwidth [shape=circle] \n", "sepalwidth -> petalwidth", "}\n"
|
||||
}
|
||||
);
|
||||
auto raw = RawDatasets("iris", true);
|
||||
auto clf = bayesnet::TAN();
|
||||
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
|
||||
REQUIRE(clf.getNumberOfNodes() == 5);
|
||||
REQUIRE(clf.getNumberOfEdges() == 7);
|
||||
REQUIRE(clf.getNumberOfStates() == 19);
|
||||
REQUIRE(clf.getClassNumStates() == 3);
|
||||
REQUIRE(clf.show() == std::vector<std::string>{"class -> sepallength, sepalwidth, petallength, petalwidth, ", "petallength -> sepallength, ", "petalwidth -> ", "sepallength -> sepalwidth, ", "sepalwidth -> petalwidth, "});
|
||||
REQUIRE(clf.graph("Test") == graph);
|
||||
SECTION("Test TAN")
|
||||
{
|
||||
auto raw = RawDatasets("iris", true);
|
||||
auto clf = bayesnet::TAN();
|
||||
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
|
||||
REQUIRE(clf.getNumberOfNodes() == 5);
|
||||
REQUIRE(clf.getNumberOfEdges() == 7);
|
||||
REQUIRE(clf.getNumberOfStates() == 19);
|
||||
REQUIRE(clf.getClassNumStates() == 3);
|
||||
REQUIRE(clf.show() == std::vector<std::string>{"class -> sepallength, sepalwidth, petallength, petalwidth, ", "petallength -> sepallength, ", "petalwidth -> ", "sepallength -> sepalwidth, ", "sepalwidth -> petalwidth, "});
|
||||
REQUIRE(clf.graph("Test") == graph);
|
||||
}
|
||||
SECTION("Test TANLd")
|
||||
{
|
||||
auto clf = bayesnet::TANLd();
|
||||
auto raw = RawDatasets("iris", false);
|
||||
clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
|
||||
REQUIRE(clf.getNumberOfNodes() == 5);
|
||||
REQUIRE(clf.getNumberOfEdges() == 7);
|
||||
REQUIRE(clf.getNumberOfStates() == 19);
|
||||
REQUIRE(clf.getClassNumStates() == 3);
|
||||
REQUIRE(clf.show() == std::vector<std::string>{"class -> sepallength, sepalwidth, petallength, petalwidth, ", "petallength -> sepallength, ", "petalwidth -> ", "sepallength -> sepalwidth, ", "sepalwidth -> petalwidth, "});
|
||||
REQUIRE(clf.graph("Test") == graph);
|
||||
}
|
||||
}
|
||||
TEST_CASE("Get num features & num edges", "[Models]")
|
||||
{
|
||||
@ -222,6 +238,12 @@ TEST_CASE("KDB with hyperparameters", "[Models]")
|
||||
REQUIRE(score == Catch::Approx(0.827103).epsilon(raw.epsilon));
|
||||
REQUIRE(scoret == Catch::Approx(0.761682).epsilon(raw.epsilon));
|
||||
}
|
||||
TEST_CASE("Incorrect type of data for SPODELd", "[Models]")
|
||||
{
|
||||
auto raw = RawDatasets("iris", true);
|
||||
auto clf = bayesnet::SPODELd(0);
|
||||
REQUIRE_THROWS_AS(clf.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest), std::runtime_error);
|
||||
}
|
||||
TEST_CASE("Predict, predict_proba & score without fitting", "[Models]")
|
||||
{
|
||||
auto clf = bayesnet::AODE();
|
||||
|
@ -157,18 +157,13 @@ TEST_CASE("Bisection", "[BoostAODE]")
|
||||
TEST_CASE("Block Update", "[BoostAODE]")
|
||||
{
|
||||
auto clf = bayesnet::BoostAODE();
|
||||
// auto raw = RawDatasets("mfeat-factors", true);
|
||||
auto raw = RawDatasets("glass", true);
|
||||
auto raw = RawDatasets("mfeat-factors", true);
|
||||
clf.setHyperparameters({
|
||||
{"bisection", true},
|
||||
{"block_update", true},
|
||||
{"maxTolerance", 3},
|
||||
{"convergence", true},
|
||||
});
|
||||
// clf.setHyperparameters({
|
||||
// {"block_update", true},
|
||||
// });
|
||||
|
||||
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
|
||||
REQUIRE(clf.getNumberOfNodes() == 217);
|
||||
REQUIRE(clf.getNumberOfEdges() == 431);
|
||||
|
@ -1,6 +1,7 @@
|
||||
#include <catch2/catch_test_macros.hpp>
|
||||
#include <catch2/catch_approx.hpp>
|
||||
#include <catch2/generators/catch_generators.hpp>
|
||||
#include <catch2/matchers/catch_matchers.hpp>
|
||||
#include "bayesnet/utils/BayesMetrics.h"
|
||||
#include "bayesnet/feature_selection/CFS.h"
|
||||
#include "bayesnet/feature_selection/FCBF.h"
|
||||
@ -68,4 +69,15 @@ TEST_CASE("Features Selected", "[FeatureSelection]")
|
||||
delete featureSelector;
|
||||
}
|
||||
}
|
||||
}
|
||||
TEST_CASE("Oddities", "[FeatureSelection]")
|
||||
{
|
||||
auto raw = RawDatasets("iris", true);
|
||||
// FCBF Limits
|
||||
REQUIRE_THROWS_AS(bayesnet::FCBF(raw.dataset, raw.featuresv, raw.classNamev, raw.featuresv.size(), raw.classNumStates, raw.weights, 1e-8), std::invalid_argument);
|
||||
REQUIRE_THROWS_WITH(bayesnet::FCBF(raw.dataset, raw.featuresv, raw.classNamev, raw.featuresv.size(), raw.classNumStates, raw.weights, 1e-8), "Threshold cannot be less than 1e-7");
|
||||
REQUIRE_THROWS_AS(bayesnet::IWSS(raw.dataset, raw.featuresv, raw.classNamev, raw.featuresv.size(), raw.classNumStates, raw.weights, -1e4), std::invalid_argument);
|
||||
REQUIRE_THROWS_WITH(bayesnet::IWSS(raw.dataset, raw.featuresv, raw.classNamev, raw.featuresv.size(), raw.classNumStates, raw.weights, -1e4), "Threshold has to be in [0, 0.5]");
|
||||
REQUIRE_THROWS_AS(bayesnet::IWSS(raw.dataset, raw.featuresv, raw.classNamev, raw.featuresv.size(), raw.classNumStates, raw.weights, 0.501), std::invalid_argument);
|
||||
REQUIRE_THROWS_WITH(bayesnet::IWSS(raw.dataset, raw.featuresv, raw.classNamev, raw.featuresv.size(), raw.classNumStates, raw.weights, 0.501), "Threshold has to be in [0, 0.5]");
|
||||
}
|
Loading…
Reference in New Issue
Block a user