bisection proposal #24
2
.vscode/launch.json
vendored
2
.vscode/launch.json
vendored
@ -16,7 +16,7 @@
|
|||||||
"name": "test",
|
"name": "test",
|
||||||
"program": "${workspaceFolder}/build_debug/tests/TestBayesNet",
|
"program": "${workspaceFolder}/build_debug/tests/TestBayesNet",
|
||||||
"args": [
|
"args": [
|
||||||
"[FeatureSelection]"
|
"[Network]"
|
||||||
//"-c=\"Metrics Test\"",
|
//"-c=\"Metrics Test\"",
|
||||||
// "-s",
|
// "-s",
|
||||||
],
|
],
|
||||||
|
18
Makefile
18
Makefile
@ -1,6 +1,6 @@
|
|||||||
SHELL := /bin/bash
|
SHELL := /bin/bash
|
||||||
.DEFAULT_GOAL := help
|
.DEFAULT_GOAL := help
|
||||||
.PHONY: coverage setup help buildr buildd test clean debug release sample
|
.PHONY: viewcoverage coverage setup help install uninstall buildr buildd test clean debug release sample
|
||||||
|
|
||||||
f_release = build_release
|
f_release = build_release
|
||||||
f_debug = build_debug
|
f_debug = build_debug
|
||||||
@ -29,6 +29,7 @@ setup: ## Install dependencies for tests and coverage
|
|||||||
fi
|
fi
|
||||||
@if [ "$(shell uname)" = "Linux" ]; then \
|
@if [ "$(shell uname)" = "Linux" ]; then \
|
||||||
pip install gcovr; \
|
pip install gcovr; \
|
||||||
|
sudo dnf install lcov;\
|
||||||
fi
|
fi
|
||||||
|
|
||||||
dependency: ## Create a dependency graph diagram of the project (build/dependency.png)
|
dependency: ## Create a dependency graph diagram of the project (build/dependency.png)
|
||||||
@ -100,6 +101,21 @@ coverage: ## Run tests and generate coverage report (build/index.html)
|
|||||||
@gcovr $(f_debug)/tests
|
@gcovr $(f_debug)/tests
|
||||||
@echo ">>> Done";
|
@echo ">>> Done";
|
||||||
|
|
||||||
|
viewcoverage: ## Run tests, generate coverage report and upload it to codecov (build/index.html)
|
||||||
|
@echo ">>> Building tests with coverage..."
|
||||||
|
@$(MAKE) coverage
|
||||||
|
@echo ">>> Building report..."
|
||||||
|
@cd $(f_debug)/tests; \
|
||||||
|
lcov --directory . --capture --output-file coverage.info >/dev/null 2>&1; \
|
||||||
|
lcov --remove coverage.info '/usr/*' --output-file coverage.info >/dev/null 2>&1; \
|
||||||
|
lcov --remove coverage.info 'lib/*' --output-file coverage.info >/dev/null 2>&1; \
|
||||||
|
lcov --remove coverage.info 'libtorch/*' --output-file coverage.info >/dev/null 2>&1; \
|
||||||
|
lcov --remove coverage.info 'tests/*' --output-file coverage.info >/dev/null 2>&1; \
|
||||||
|
lcov --remove coverage.info 'bayesnet/utils/loguru.*' --output-file coverage.info >/dev/null 2>&1; \
|
||||||
|
genhtml coverage.info --output-directory $(f_debug)/tests/coverage >/dev/null 2>&1; \
|
||||||
|
xdg-open $(f_debug)/tests/coverage/index.html || open $(f_debug)/tests/coverage/index.html 2>/dev/null
|
||||||
|
@echo ">>> Done";
|
||||||
|
|
||||||
|
|
||||||
help: ## Show help message
|
help: ## Show help message
|
||||||
@IFS=$$'\n' ; \
|
@IFS=$$'\n' ; \
|
||||||
|
@ -3,19 +3,6 @@
|
|||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
AODELd::AODELd(bool predict_voting) : Ensemble(predict_voting), Proposal(dataset, features, className)
|
AODELd::AODELd(bool predict_voting) : Ensemble(predict_voting), Proposal(dataset, features, className)
|
||||||
{
|
{
|
||||||
validHyperparameters = { "predict_voting" };
|
|
||||||
|
|
||||||
}
|
|
||||||
void AODELd::setHyperparameters(const nlohmann::json& hyperparameters_)
|
|
||||||
{
|
|
||||||
auto hyperparameters = hyperparameters_;
|
|
||||||
if (hyperparameters.contains("predict_voting")) {
|
|
||||||
predict_voting = hyperparameters["predict_voting"];
|
|
||||||
hyperparameters.erase("predict_voting");
|
|
||||||
}
|
|
||||||
if (!hyperparameters.empty()) {
|
|
||||||
throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
AODELd& AODELd::fit(torch::Tensor& X_, torch::Tensor& y_, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_)
|
AODELd& AODELd::fit(torch::Tensor& X_, torch::Tensor& y_, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_)
|
||||||
{
|
{
|
||||||
|
@ -10,7 +10,6 @@ namespace bayesnet {
|
|||||||
AODELd(bool predict_voting = true);
|
AODELd(bool predict_voting = true);
|
||||||
virtual ~AODELd() = default;
|
virtual ~AODELd() = default;
|
||||||
AODELd& fit(torch::Tensor& X_, torch::Tensor& y_, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_) override;
|
AODELd& fit(torch::Tensor& X_, torch::Tensor& y_, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_) override;
|
||||||
void setHyperparameters(const nlohmann::json& hyperparameters) override;
|
|
||||||
std::vector<std::string> graph(const std::string& name = "AODELd") const override;
|
std::vector<std::string> graph(const std::string& name = "AODELd") const override;
|
||||||
protected:
|
protected:
|
||||||
void trainModel(const torch::Tensor& weights) override;
|
void trainModel(const torch::Tensor& weights) override;
|
||||||
|
@ -1,27 +1,35 @@
|
|||||||
#include <thread>
|
#include <thread>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
|
#include <sstream>
|
||||||
#include "Network.h"
|
#include "Network.h"
|
||||||
#include "bayesnet/utils/bayesnetUtils.h"
|
#include "bayesnet/utils/bayesnetUtils.h"
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
Network::Network() : features(std::vector<std::string>()), className(""), classNumStates(0), fitted(false), laplaceSmoothing(0) {}
|
Network::Network() : fitted{ false }, maxThreads{ 0.95 }, classNumStates{ 0 }, laplaceSmoothing{ 0 }
|
||||||
Network::Network(float maxT) : features(std::vector<std::string>()), className(""), classNumStates(0), maxThreads(maxT), fitted(false), laplaceSmoothing(0) {}
|
|
||||||
Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), maxThreads(other.
|
|
||||||
getmaxThreads()), fitted(other.fitted)
|
|
||||||
{
|
{
|
||||||
|
}
|
||||||
|
Network::Network(float maxT) : fitted{ false }, maxThreads{ maxT }, classNumStates{ 0 }, laplaceSmoothing{ 0 }
|
||||||
|
{
|
||||||
|
|
||||||
|
}
|
||||||
|
Network::Network(const Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()),
|
||||||
|
maxThreads(other.getMaxThreads()), fitted(other.fitted), samples(other.samples)
|
||||||
|
{
|
||||||
|
if (samples.defined())
|
||||||
|
samples = samples.clone();
|
||||||
for (const auto& node : other.nodes) {
|
for (const auto& node : other.nodes) {
|
||||||
nodes[node.first] = std::make_unique<Node>(*node.second);
|
nodes[node.first] = std::make_unique<Node>(*node.second);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
void Network::initialize()
|
void Network::initialize()
|
||||||
{
|
{
|
||||||
features = std::vector<std::string>();
|
features.clear();
|
||||||
className = "";
|
className = "";
|
||||||
classNumStates = 0;
|
classNumStates = 0;
|
||||||
fitted = false;
|
fitted = false;
|
||||||
nodes.clear();
|
nodes.clear();
|
||||||
samples = torch::Tensor();
|
samples = torch::Tensor();
|
||||||
}
|
}
|
||||||
float Network::getmaxThreads()
|
float Network::getMaxThreads() const
|
||||||
{
|
{
|
||||||
return maxThreads;
|
return maxThreads;
|
||||||
}
|
}
|
||||||
@ -114,11 +122,14 @@ namespace bayesnet {
|
|||||||
if (n_features != featureNames.size()) {
|
if (n_features != featureNames.size()) {
|
||||||
throw std::invalid_argument("X and features must have the same number of features in Network::fit (" + std::to_string(n_features) + " != " + std::to_string(featureNames.size()) + ")");
|
throw std::invalid_argument("X and features must have the same number of features in Network::fit (" + std::to_string(n_features) + " != " + std::to_string(featureNames.size()) + ")");
|
||||||
}
|
}
|
||||||
|
if (features.size() == 0) {
|
||||||
|
throw std::invalid_argument("The network has not been initialized. You must call addNode() before calling fit()");
|
||||||
|
}
|
||||||
if (n_features != features.size() - 1) {
|
if (n_features != features.size() - 1) {
|
||||||
throw std::invalid_argument("X and local features must have the same number of features in Network::fit (" + std::to_string(n_features) + " != " + std::to_string(features.size() - 1) + ")");
|
throw std::invalid_argument("X and local features must have the same number of features in Network::fit (" + std::to_string(n_features) + " != " + std::to_string(features.size() - 1) + ")");
|
||||||
}
|
}
|
||||||
if (find(features.begin(), features.end(), className) == features.end()) {
|
if (find(features.begin(), features.end(), className) == features.end()) {
|
||||||
throw std::invalid_argument("className not found in Network::features");
|
throw std::invalid_argument("Class Name not found in Network::features");
|
||||||
}
|
}
|
||||||
for (auto& feature : featureNames) {
|
for (auto& feature : featureNames) {
|
||||||
if (find(features.begin(), features.end(), feature) == features.end()) {
|
if (find(features.begin(), features.end(), feature) == features.end()) {
|
||||||
@ -404,11 +415,13 @@ namespace bayesnet {
|
|||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
void Network::dump_cpt() const
|
std::string Network::dump_cpt() const
|
||||||
{
|
{
|
||||||
|
std::stringstream oss;
|
||||||
for (auto& node : nodes) {
|
for (auto& node : nodes) {
|
||||||
std::cout << "* " << node.first << ": (" << node.second->getNumStates() << ") : " << node.second->getCPT().sizes() << std::endl;
|
oss << "* " << node.first << ": (" << node.second->getNumStates() << ") : " << node.second->getCPT().sizes() << std::endl;
|
||||||
std::cout << node.second->getCPT() << std::endl;
|
oss << node.second->getCPT() << std::endl;
|
||||||
}
|
}
|
||||||
|
return oss.str();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -10,10 +10,10 @@ namespace bayesnet {
|
|||||||
public:
|
public:
|
||||||
Network();
|
Network();
|
||||||
explicit Network(float);
|
explicit Network(float);
|
||||||
explicit Network(Network&);
|
explicit Network(const Network&);
|
||||||
~Network() = default;
|
~Network() = default;
|
||||||
torch::Tensor& getSamples();
|
torch::Tensor& getSamples();
|
||||||
float getmaxThreads();
|
float getMaxThreads() const;
|
||||||
void addNode(const std::string&);
|
void addNode(const std::string&);
|
||||||
void addEdge(const std::string&, const std::string&);
|
void addEdge(const std::string&, const std::string&);
|
||||||
std::map<std::string, std::unique_ptr<Node>>& getNodes();
|
std::map<std::string, std::unique_ptr<Node>>& getNodes();
|
||||||
@ -39,7 +39,7 @@ namespace bayesnet {
|
|||||||
std::vector<std::string> show() const;
|
std::vector<std::string> show() const;
|
||||||
std::vector<std::string> graph(const std::string& title) const; // Returns a std::vector of std::strings representing the graph in graphviz format
|
std::vector<std::string> graph(const std::string& title) const; // Returns a std::vector of std::strings representing the graph in graphviz format
|
||||||
void initialize();
|
void initialize();
|
||||||
void dump_cpt() const;
|
std::string dump_cpt() const;
|
||||||
inline std::string version() { return { project_version.begin(), project_version.end() }; }
|
inline std::string version() { return { project_version.begin(), project_version.end() }; }
|
||||||
private:
|
private:
|
||||||
std::map<std::string, std::unique_ptr<Node>> nodes;
|
std::map<std::string, std::unique_ptr<Node>> nodes;
|
||||||
@ -49,7 +49,7 @@ namespace bayesnet {
|
|||||||
std::vector<std::string> features; // Including classname
|
std::vector<std::string> features; // Including classname
|
||||||
std::string className;
|
std::string className;
|
||||||
double laplaceSmoothing;
|
double laplaceSmoothing;
|
||||||
torch::Tensor samples; // nxm tensor used to fit the model
|
torch::Tensor samples; // n+1xm tensor used to fit the model
|
||||||
bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
|
bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
|
||||||
std::vector<double> predict_sample(const std::vector<int>&);
|
std::vector<double> predict_sample(const std::vector<int>&);
|
||||||
std::vector<double> predict_sample(const torch::Tensor&);
|
std::vector<double> predict_sample(const torch::Tensor&);
|
||||||
|
@ -9,12 +9,12 @@ namespace bayesnet {
|
|||||||
, classNumStates(classNumStates)
|
, classNumStates(classNumStates)
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
//samples is nxm std::vector used to fit the model
|
//samples is n+1xm std::vector used to fit the model
|
||||||
Metrics::Metrics(const std::vector<std::vector<int>>& vsamples, const std::vector<int>& labels, const std::vector<std::string>& features, const std::string& className, const int classNumStates)
|
Metrics::Metrics(const std::vector<std::vector<int>>& vsamples, const std::vector<int>& labels, const std::vector<std::string>& features, const std::string& className, const int classNumStates)
|
||||||
: features(features)
|
: features(features)
|
||||||
, className(className)
|
, className(className)
|
||||||
, classNumStates(classNumStates)
|
, classNumStates(classNumStates)
|
||||||
, samples(torch::zeros({ static_cast<int>(vsamples[0].size()), static_cast<int>(vsamples.size() + 1) }, torch::kInt32))
|
, samples(torch::zeros({ static_cast<int>(vsamples.size() + 1), static_cast<int>(vsamples[0].size()) }, torch::kInt32))
|
||||||
{
|
{
|
||||||
for (int i = 0; i < vsamples.size(); ++i) {
|
for (int i = 0; i < vsamples.size(); ++i) {
|
||||||
samples.index_put_({ i, "..." }, torch::tensor(vsamples[i], torch::kInt32));
|
samples.index_put_({ i, "..." }, torch::tensor(vsamples[i], torch::kInt32));
|
||||||
|
@ -5,11 +5,16 @@
|
|||||||
#include <torch/torch.h>
|
#include <torch/torch.h>
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
class Metrics {
|
class Metrics {
|
||||||
private:
|
public:
|
||||||
int classNumStates = 0;
|
Metrics() = default;
|
||||||
std::vector<double> scoresKBest;
|
Metrics(const torch::Tensor& samples, const std::vector<std::string>& features, const std::string& className, const int classNumStates);
|
||||||
std::vector<int> featuresKBest; // sorted indices of the features
|
Metrics(const std::vector<std::vector<int>>& vsamples, const std::vector<int>& labels, const std::vector<std::string>& features, const std::string& className, const int classNumStates);
|
||||||
double conditionalEntropy(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights);
|
std::vector<int> SelectKBestWeighted(const torch::Tensor& weights, bool ascending = false, unsigned k = 0);
|
||||||
|
std::vector<double> getScoresKBest() const;
|
||||||
|
double mutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights);
|
||||||
|
std::vector<float> conditionalEdgeWeights(std::vector<float>& weights); // To use in Python
|
||||||
|
torch::Tensor conditionalEdge(const torch::Tensor& weights);
|
||||||
|
std::vector<std::pair<int, int>> maximumSpanningTree(const std::vector<std::string>& features, const torch::Tensor& weights, const int root);
|
||||||
protected:
|
protected:
|
||||||
torch::Tensor samples; // n+1xm torch::Tensor used to fit the model where samples[-1] is the y std::vector
|
torch::Tensor samples; // n+1xm torch::Tensor used to fit the model where samples[-1] is the y std::vector
|
||||||
std::string className;
|
std::string className;
|
||||||
@ -34,16 +39,11 @@ namespace bayesnet {
|
|||||||
v.erase(v.begin());
|
v.erase(v.begin());
|
||||||
return temp;
|
return temp;
|
||||||
}
|
}
|
||||||
public:
|
private:
|
||||||
Metrics() = default;
|
int classNumStates = 0;
|
||||||
Metrics(const torch::Tensor& samples, const std::vector<std::string>& features, const std::string& className, const int classNumStates);
|
std::vector<double> scoresKBest;
|
||||||
Metrics(const std::vector<std::vector<int>>& vsamples, const std::vector<int>& labels, const std::vector<std::string>& features, const std::string& className, const int classNumStates);
|
std::vector<int> featuresKBest; // sorted indices of the features
|
||||||
std::vector<int> SelectKBestWeighted(const torch::Tensor& weights, bool ascending = false, unsigned k = 0);
|
double conditionalEntropy(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights);
|
||||||
std::vector<double> getScoresKBest() const;
|
|
||||||
double mutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights);
|
|
||||||
std::vector<float> conditionalEdgeWeights(std::vector<float>& weights); // To use in Python
|
|
||||||
torch::Tensor conditionalEdge(const torch::Tensor& weights);
|
|
||||||
std::vector<std::pair<int, int>> maximumSpanningTree(const std::vector<std::string>& features, const torch::Tensor& weights, const int root);
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
@ -8,10 +8,11 @@ if(ENABLE_TESTING)
|
|||||||
${CMAKE_BINARY_DIR}/configured_files/include
|
${CMAKE_BINARY_DIR}/configured_files/include
|
||||||
)
|
)
|
||||||
file(GLOB_RECURSE BayesNet_SOURCES "${BayesNet_SOURCE_DIR}/bayesnet/*.cc")
|
file(GLOB_RECURSE BayesNet_SOURCES "${BayesNet_SOURCE_DIR}/bayesnet/*.cc")
|
||||||
add_executable(TestBayesNet TestBayesNetwork.cc TestBayesModels.cc TestBayesMetrics.cc TestFeatureSelection.cc TestUtils.cc ${BayesNet_SOURCES})
|
add_executable(TestBayesNet TestBayesNetwork.cc TestBayesNode.cc TestBayesModels.cc TestBayesMetrics.cc TestFeatureSelection.cc TestUtils.cc ${BayesNet_SOURCES})
|
||||||
target_link_libraries(TestBayesNet PUBLIC "${TORCH_LIBRARIES}" ArffFiles mdlp Catch2::Catch2WithMain )
|
target_link_libraries(TestBayesNet PUBLIC "${TORCH_LIBRARIES}" ArffFiles mdlp Catch2::Catch2WithMain )
|
||||||
add_test(NAME BayesNetworkTest COMMAND TestBayesNet)
|
add_test(NAME BayesNetworkTest COMMAND TestBayesNet)
|
||||||
add_test(NAME Network COMMAND TestBayesNet "[Network]")
|
add_test(NAME Network COMMAND TestBayesNet "[Network]")
|
||||||
|
add_test(NAME Node COMMAND TestBayesNet "[Node]")
|
||||||
add_test(NAME Metrics COMMAND TestBayesNet "[Metrics]")
|
add_test(NAME Metrics COMMAND TestBayesNet "[Metrics]")
|
||||||
add_test(NAME FeatureSelection COMMAND TestBayesNet "[FeatureSelection]")
|
add_test(NAME FeatureSelection COMMAND TestBayesNet "[FeatureSelection]")
|
||||||
add_test(NAME Models COMMAND TestBayesNet "[Models]")
|
add_test(NAME Models COMMAND TestBayesNet "[Models]")
|
||||||
|
@ -32,31 +32,41 @@ TEST_CASE("Metrics Test", "[Metrics]")
|
|||||||
};
|
};
|
||||||
auto raw = RawDatasets(file_name, true);
|
auto raw = RawDatasets(file_name, true);
|
||||||
bayesnet::Metrics metrics(raw.dataset, raw.featurest, raw.classNamet, raw.classNumStates);
|
bayesnet::Metrics metrics(raw.dataset, raw.featurest, raw.classNamet, raw.classNumStates);
|
||||||
|
bayesnet::Metrics metricsv(raw.Xv, raw.yv, raw.featurest, raw.classNamet, raw.classNumStates);
|
||||||
|
|
||||||
SECTION("Test Constructor")
|
SECTION("Test Constructor")
|
||||||
{
|
{
|
||||||
REQUIRE(metrics.getScoresKBest().size() == 0);
|
REQUIRE(metrics.getScoresKBest().size() == 0);
|
||||||
|
REQUIRE(metricsv.getScoresKBest().size() == 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
SECTION("Test SelectKBestWeighted")
|
SECTION("Test SelectKBestWeighted")
|
||||||
{
|
{
|
||||||
std::vector<int> kBest = metrics.SelectKBestWeighted(raw.weights, true, resultsKBest.at(file_name).first);
|
std::vector<int> kBest = metrics.SelectKBestWeighted(raw.weights, true, resultsKBest.at(file_name).first);
|
||||||
|
std::vector<int> kBestv = metricsv.SelectKBestWeighted(raw.weights, true, resultsKBest.at(file_name).first);
|
||||||
REQUIRE(kBest.size() == resultsKBest.at(file_name).first);
|
REQUIRE(kBest.size() == resultsKBest.at(file_name).first);
|
||||||
|
REQUIRE(kBestv.size() == resultsKBest.at(file_name).first);
|
||||||
REQUIRE(kBest == resultsKBest.at(file_name).second);
|
REQUIRE(kBest == resultsKBest.at(file_name).second);
|
||||||
|
REQUIRE(kBestv == resultsKBest.at(file_name).second);
|
||||||
}
|
}
|
||||||
|
|
||||||
SECTION("Test Mutual Information")
|
SECTION("Test Mutual Information")
|
||||||
{
|
{
|
||||||
auto result = metrics.mutualInformation(raw.dataset.index({ 1, "..." }), raw.dataset.index({ 2, "..." }), raw.weights);
|
auto result = metrics.mutualInformation(raw.dataset.index({ 1, "..." }), raw.dataset.index({ 2, "..." }), raw.weights);
|
||||||
|
auto resultv = metricsv.mutualInformation(raw.dataset.index({ 1, "..." }), raw.dataset.index({ 2, "..." }), raw.weights);
|
||||||
REQUIRE(result == Catch::Approx(resultsMI.at(file_name)).epsilon(raw.epsilon));
|
REQUIRE(result == Catch::Approx(resultsMI.at(file_name)).epsilon(raw.epsilon));
|
||||||
|
REQUIRE(resultv == Catch::Approx(resultsMI.at(file_name)).epsilon(raw.epsilon));
|
||||||
}
|
}
|
||||||
|
|
||||||
SECTION("Test Maximum Spanning Tree")
|
SECTION("Test Maximum Spanning Tree")
|
||||||
{
|
{
|
||||||
auto weights_matrix = metrics.conditionalEdge(raw.weights);
|
auto weights_matrix = metrics.conditionalEdge(raw.weights);
|
||||||
|
auto weights_matrixv = metricsv.conditionalEdge(raw.weights);
|
||||||
for (int i = 0; i < 2; ++i) {
|
for (int i = 0; i < 2; ++i) {
|
||||||
auto result = metrics.maximumSpanningTree(raw.featurest, weights_matrix, i);
|
auto result = metrics.maximumSpanningTree(raw.featurest, weights_matrix, i);
|
||||||
|
auto resultv = metricsv.maximumSpanningTree(raw.featurest, weights_matrixv, i);
|
||||||
REQUIRE(result == resultsMST.at({ file_name, i }));
|
REQUIRE(result == resultsMST.at({ file_name, i }));
|
||||||
|
REQUIRE(resultv == resultsMST.at({ file_name, i }));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -1,3 +1,4 @@
|
|||||||
|
#include <type_traits>
|
||||||
#include <catch2/catch_test_macros.hpp>
|
#include <catch2/catch_test_macros.hpp>
|
||||||
#include <catch2/catch_approx.hpp>
|
#include <catch2/catch_approx.hpp>
|
||||||
#include <catch2/generators/catch_generators.hpp>
|
#include <catch2/generators/catch_generators.hpp>
|
||||||
@ -98,6 +99,30 @@ TEST_CASE("BoostAODE feature_select CFS", "[Models]")
|
|||||||
REQUIRE(clf.getNotes()[0] == "Used features in initialization: 6 of 9 with CFS");
|
REQUIRE(clf.getNotes()[0] == "Used features in initialization: 6 of 9 with CFS");
|
||||||
REQUIRE(clf.getNotes()[1] == "Number of models: 9");
|
REQUIRE(clf.getNotes()[1] == "Number of models: 9");
|
||||||
}
|
}
|
||||||
|
TEST_CASE("BoostAODE feature_select IWSS", "[Models]")
|
||||||
|
{
|
||||||
|
auto raw = RawDatasets("glass", true);
|
||||||
|
auto clf = bayesnet::BoostAODE();
|
||||||
|
clf.setHyperparameters({ {"select_features", "IWSS"}, {"threshold", 0.5 } });
|
||||||
|
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
|
||||||
|
REQUIRE(clf.getNumberOfNodes() == 90);
|
||||||
|
REQUIRE(clf.getNumberOfEdges() == 153);
|
||||||
|
REQUIRE(clf.getNotes().size() == 2);
|
||||||
|
REQUIRE(clf.getNotes()[0] == "Used features in initialization: 5 of 9 with IWSS");
|
||||||
|
REQUIRE(clf.getNotes()[1] == "Number of models: 9");
|
||||||
|
}
|
||||||
|
TEST_CASE("BoostAODE feature_select FCBF", "[Models]")
|
||||||
|
{
|
||||||
|
auto raw = RawDatasets("glass", true);
|
||||||
|
auto clf = bayesnet::BoostAODE();
|
||||||
|
clf.setHyperparameters({ {"select_features", "FCBF"}, {"threshold", 1e-7 } });
|
||||||
|
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
|
||||||
|
REQUIRE(clf.getNumberOfNodes() == 90);
|
||||||
|
REQUIRE(clf.getNumberOfEdges() == 153);
|
||||||
|
REQUIRE(clf.getNotes().size() == 2);
|
||||||
|
REQUIRE(clf.getNotes()[0] == "Used features in initialization: 5 of 9 with FCBF");
|
||||||
|
REQUIRE(clf.getNotes()[1] == "Number of models: 9");
|
||||||
|
}
|
||||||
TEST_CASE("BoostAODE test used features in train note and score", "[Models]")
|
TEST_CASE("BoostAODE test used features in train note and score", "[Models]")
|
||||||
{
|
{
|
||||||
auto raw = RawDatasets("diabetes", true);
|
auto raw = RawDatasets("diabetes", true);
|
||||||
@ -246,7 +271,7 @@ TEST_CASE("SPODELd dataset", "[Models]")
|
|||||||
{
|
{
|
||||||
auto raw = RawDatasets("iris", false);
|
auto raw = RawDatasets("iris", false);
|
||||||
auto clf = bayesnet::SPODELd(0);
|
auto clf = bayesnet::SPODELd(0);
|
||||||
raw.dataset.to(torch::kFloat32);
|
// raw.dataset.to(torch::kFloat32);
|
||||||
clf.fit(raw.dataset, raw.featuresv, raw.classNamev, raw.statesv);
|
clf.fit(raw.dataset, raw.featuresv, raw.classNamev, raw.statesv);
|
||||||
auto score = clf.score(raw.Xt, raw.yt);
|
auto score = clf.score(raw.Xt, raw.yt);
|
||||||
clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
|
clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
|
||||||
|
@ -2,9 +2,11 @@
|
|||||||
#include <catch2/catch_test_macros.hpp>
|
#include <catch2/catch_test_macros.hpp>
|
||||||
#include <catch2/catch_approx.hpp>
|
#include <catch2/catch_approx.hpp>
|
||||||
#include <catch2/generators/catch_generators.hpp>
|
#include <catch2/generators/catch_generators.hpp>
|
||||||
|
#include <catch2/matchers/catch_matchers.hpp>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "TestUtils.h"
|
#include "TestUtils.h"
|
||||||
#include "bayesnet/network/Network.h"
|
#include "bayesnet/network/Network.h"
|
||||||
|
#include "bayesnet/utils/bayesnetUtils.h"
|
||||||
|
|
||||||
void buildModel(bayesnet::Network& net, const std::vector<std::string>& features, const std::string& className)
|
void buildModel(bayesnet::Network& net, const std::vector<std::string>& features, const std::string& className)
|
||||||
{
|
{
|
||||||
@ -111,6 +113,22 @@ TEST_CASE("Test Bayesian Network", "[Network]")
|
|||||||
net3.fit(raw.Xt, raw.yt, raw.weights, raw.featurest, raw.classNamet, raw.statest);
|
net3.fit(raw.Xt, raw.yt, raw.weights, raw.featurest, raw.classNamet, raw.statest);
|
||||||
REQUIRE(net.getStates() == net2.getStates());
|
REQUIRE(net.getStates() == net2.getStates());
|
||||||
REQUIRE(net.getStates() == net3.getStates());
|
REQUIRE(net.getStates() == net3.getStates());
|
||||||
|
REQUIRE(net.getFeatures() == net2.getFeatures());
|
||||||
|
REQUIRE(net.getFeatures() == net3.getFeatures());
|
||||||
|
REQUIRE(net.getClassName() == net2.getClassName());
|
||||||
|
REQUIRE(net.getClassName() == net3.getClassName());
|
||||||
|
REQUIRE(net.getNodes().size() == net2.getNodes().size());
|
||||||
|
REQUIRE(net.getNodes().size() == net3.getNodes().size());
|
||||||
|
REQUIRE(net.getEdges() == net2.getEdges());
|
||||||
|
REQUIRE(net.getEdges() == net3.getEdges());
|
||||||
|
REQUIRE(net.getNumEdges() == net2.getNumEdges());
|
||||||
|
REQUIRE(net.getNumEdges() == net3.getNumEdges());
|
||||||
|
REQUIRE(net.getClassNumStates() == net2.getClassNumStates());
|
||||||
|
REQUIRE(net.getClassNumStates() == net3.getClassNumStates());
|
||||||
|
REQUIRE(net.getSamples().size(0) == net2.getSamples().size(0));
|
||||||
|
REQUIRE(net.getSamples().size(0) == net3.getSamples().size(0));
|
||||||
|
REQUIRE(net.getSamples().size(1) == net2.getSamples().size(1));
|
||||||
|
REQUIRE(net.getSamples().size(1) == net3.getSamples().size(1));
|
||||||
// Check Conditional Probabilities tables
|
// Check Conditional Probabilities tables
|
||||||
for (int i = 0; i < features.size(); ++i) {
|
for (int i = 0; i < features.size(); ++i) {
|
||||||
auto feature = features.at(i);
|
auto feature = features.at(i);
|
||||||
@ -125,7 +143,6 @@ TEST_CASE("Test Bayesian Network", "[Network]")
|
|||||||
}
|
}
|
||||||
SECTION("Test show")
|
SECTION("Test show")
|
||||||
{
|
{
|
||||||
auto net = bayesnet::Network();
|
|
||||||
net.addNode("A");
|
net.addNode("A");
|
||||||
net.addNode("B");
|
net.addNode("B");
|
||||||
net.addNode("C");
|
net.addNode("C");
|
||||||
@ -139,7 +156,6 @@ TEST_CASE("Test Bayesian Network", "[Network]")
|
|||||||
}
|
}
|
||||||
SECTION("Test topological_sort")
|
SECTION("Test topological_sort")
|
||||||
{
|
{
|
||||||
auto net = bayesnet::Network();
|
|
||||||
net.addNode("A");
|
net.addNode("A");
|
||||||
net.addNode("B");
|
net.addNode("B");
|
||||||
net.addNode("C");
|
net.addNode("C");
|
||||||
@ -153,7 +169,6 @@ TEST_CASE("Test Bayesian Network", "[Network]")
|
|||||||
}
|
}
|
||||||
SECTION("Test graph")
|
SECTION("Test graph")
|
||||||
{
|
{
|
||||||
auto net = bayesnet::Network();
|
|
||||||
net.addNode("A");
|
net.addNode("A");
|
||||||
net.addNode("B");
|
net.addNode("B");
|
||||||
net.addNode("C");
|
net.addNode("C");
|
||||||
@ -171,7 +186,6 @@ TEST_CASE("Test Bayesian Network", "[Network]")
|
|||||||
}
|
}
|
||||||
SECTION("Test predict")
|
SECTION("Test predict")
|
||||||
{
|
{
|
||||||
auto net = bayesnet::Network();
|
|
||||||
buildModel(net, raw.featuresv, raw.classNamev);
|
buildModel(net, raw.featuresv, raw.classNamev);
|
||||||
net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv);
|
net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv);
|
||||||
std::vector<std::vector<int>> test = { {1, 2, 0, 1, 1}, {0, 1, 2, 0, 1}, {0, 0, 0, 0, 1}, {2, 2, 2, 2, 1} };
|
std::vector<std::vector<int>> test = { {1, 2, 0, 1, 1}, {0, 1, 2, 0, 1}, {0, 0, 0, 0, 1}, {2, 2, 2, 2, 1} };
|
||||||
@ -181,7 +195,6 @@ TEST_CASE("Test Bayesian Network", "[Network]")
|
|||||||
}
|
}
|
||||||
SECTION("Test predict_proba")
|
SECTION("Test predict_proba")
|
||||||
{
|
{
|
||||||
auto net = bayesnet::Network();
|
|
||||||
buildModel(net, raw.featuresv, raw.classNamev);
|
buildModel(net, raw.featuresv, raw.classNamev);
|
||||||
net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv);
|
net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv);
|
||||||
std::vector<std::vector<int>> test = { {1, 2, 0, 1, 1}, {0, 1, 2, 0, 1}, {0, 0, 0, 0, 1}, {2, 2, 2, 2, 1} };
|
std::vector<std::vector<int>> test = { {1, 2, 0, 1, 1}, {0, 1, 2, 0, 1}, {0, 0, 0, 0, 1}, {2, 2, 2, 2, 1} };
|
||||||
@ -203,10 +216,230 @@ TEST_CASE("Test Bayesian Network", "[Network]")
|
|||||||
}
|
}
|
||||||
SECTION("Test score")
|
SECTION("Test score")
|
||||||
{
|
{
|
||||||
auto net = bayesnet::Network();
|
|
||||||
buildModel(net, raw.featuresv, raw.classNamev);
|
buildModel(net, raw.featuresv, raw.classNamev);
|
||||||
net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv);
|
net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv);
|
||||||
auto score = net.score(raw.Xv, raw.yv);
|
auto score = net.score(raw.Xv, raw.yv);
|
||||||
REQUIRE(score == Catch::Approx(0.97333333).margin(threshold));
|
REQUIRE(score == Catch::Approx(0.97333333).margin(threshold));
|
||||||
}
|
}
|
||||||
|
SECTION("Copy constructor")
|
||||||
|
{
|
||||||
|
buildModel(net, raw.featuresv, raw.classNamev);
|
||||||
|
net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv);
|
||||||
|
auto net2 = bayesnet::Network(net);
|
||||||
|
REQUIRE(net.getFeatures() == net2.getFeatures());
|
||||||
|
REQUIRE(net.getEdges() == net2.getEdges());
|
||||||
|
REQUIRE(net.getNumEdges() == net2.getNumEdges());
|
||||||
|
REQUIRE(net.getStates() == net2.getStates());
|
||||||
|
REQUIRE(net.getClassName() == net2.getClassName());
|
||||||
|
REQUIRE(net.getClassNumStates() == net2.getClassNumStates());
|
||||||
|
REQUIRE(net.getSamples().size(0) == net2.getSamples().size(0));
|
||||||
|
REQUIRE(net.getSamples().size(1) == net2.getSamples().size(1));
|
||||||
|
REQUIRE(net.getNodes().size() == net2.getNodes().size());
|
||||||
|
for (const auto& feature : net.getFeatures()) {
|
||||||
|
auto& node = net.getNodes().at(feature);
|
||||||
|
auto& node2 = net2.getNodes().at(feature);
|
||||||
|
REQUIRE(node->getName() == node2->getName());
|
||||||
|
REQUIRE(node->getChildren().size() == node2->getChildren().size());
|
||||||
|
REQUIRE(node->getParents().size() == node2->getParents().size());
|
||||||
|
REQUIRE(node->getCPT().equal(node2->getCPT()));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
SECTION("Test oddities")
|
||||||
|
{
|
||||||
|
buildModel(net, raw.featuresv, raw.classNamev);
|
||||||
|
// predict without fitting
|
||||||
|
std::vector<std::vector<int>> test = { {1, 2, 0, 1, 1}, {0, 1, 2, 0, 1}, {0, 0, 0, 0, 1}, {2, 2, 2, 2, 1} };
|
||||||
|
auto test_tensor = bayesnet::vectorToTensor(test);
|
||||||
|
REQUIRE_THROWS_AS(net.predict(test), std::logic_error);
|
||||||
|
REQUIRE_THROWS_WITH(net.predict(test), "You must call fit() before calling predict()");
|
||||||
|
REQUIRE_THROWS_AS(net.predict(test_tensor), std::logic_error);
|
||||||
|
REQUIRE_THROWS_WITH(net.predict(test_tensor), "You must call fit() before calling predict()");
|
||||||
|
REQUIRE_THROWS_AS(net.predict_proba(test), std::logic_error);
|
||||||
|
REQUIRE_THROWS_WITH(net.predict_proba(test), "You must call fit() before calling predict_proba()");
|
||||||
|
REQUIRE_THROWS_AS(net.score(raw.Xv, raw.yv), std::logic_error);
|
||||||
|
REQUIRE_THROWS_WITH(net.score(raw.Xv, raw.yv), "You must call fit() before calling predict()");
|
||||||
|
// predict with wrong data
|
||||||
|
auto netx = bayesnet::Network();
|
||||||
|
buildModel(netx, raw.featuresv, raw.classNamev);
|
||||||
|
netx.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv);
|
||||||
|
std::vector<std::vector<int>> test2 = { {1, 2, 0, 1, 1}, {0, 1, 2, 0, 1}, {0, 0, 0, 0, 1} };
|
||||||
|
auto test_tensor2 = bayesnet::vectorToTensor(test2, false);
|
||||||
|
REQUIRE_THROWS_AS(netx.predict(test2), std::logic_error);
|
||||||
|
REQUIRE_THROWS_WITH(netx.predict(test2), "Sample size (3) does not match the number of features (4)");
|
||||||
|
REQUIRE_THROWS_AS(netx.predict(test_tensor2), std::logic_error);
|
||||||
|
REQUIRE_THROWS_WITH(netx.predict(test_tensor2), "Sample size (3) does not match the number of features (4)");
|
||||||
|
// fit with wrong data
|
||||||
|
// Weights
|
||||||
|
auto net2 = bayesnet::Network();
|
||||||
|
REQUIRE_THROWS_AS(net2.fit(raw.Xv, raw.yv, std::vector<double>(), raw.featuresv, raw.classNamev, raw.statesv), std::invalid_argument);
|
||||||
|
std::string invalid_weights = "Weights (0) must have the same number of elements as samples (150) in Network::fit";
|
||||||
|
REQUIRE_THROWS_WITH(net2.fit(raw.Xv, raw.yv, std::vector<double>(), raw.featuresv, raw.classNamev, raw.statesv), invalid_weights);
|
||||||
|
// X & y
|
||||||
|
std::string invalid_labels = "X and y must have the same number of samples in Network::fit (150 != 0)";
|
||||||
|
REQUIRE_THROWS_AS(net2.fit(raw.Xv, std::vector<int>(), raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv), std::invalid_argument);
|
||||||
|
REQUIRE_THROWS_WITH(net2.fit(raw.Xv, std::vector<int>(), raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv), invalid_labels);
|
||||||
|
// Features
|
||||||
|
std::string invalid_features = "X and features must have the same number of features in Network::fit (4 != 0)";
|
||||||
|
REQUIRE_THROWS_AS(net2.fit(raw.Xv, raw.yv, raw.weightsv, std::vector<std::string>(), raw.classNamev, raw.statesv), std::invalid_argument);
|
||||||
|
REQUIRE_THROWS_WITH(net2.fit(raw.Xv, raw.yv, raw.weightsv, std::vector<std::string>(), raw.classNamev, raw.statesv), invalid_features);
|
||||||
|
// Different number of features
|
||||||
|
auto net3 = bayesnet::Network();
|
||||||
|
auto test2y = { 1, 2, 3, 4, 5 };
|
||||||
|
buildModel(net3, raw.featuresv, raw.classNamev);
|
||||||
|
auto features3 = raw.featuresv;
|
||||||
|
features3.pop_back();
|
||||||
|
std::string invalid_features2 = "X and local features must have the same number of features in Network::fit (3 != 4)";
|
||||||
|
REQUIRE_THROWS_AS(net3.fit(test2, test2y, std::vector<double>(5, 0), features3, raw.classNamev, raw.statesv), std::invalid_argument);
|
||||||
|
REQUIRE_THROWS_WITH(net3.fit(test2, test2y, std::vector<double>(5, 0), features3, raw.classNamev, raw.statesv), invalid_features2);
|
||||||
|
// Uninitialized network
|
||||||
|
std::string network_invalid = "The network has not been initialized. You must call addNode() before calling fit()";
|
||||||
|
REQUIRE_THROWS_AS(net2.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, "duck", raw.statesv), std::invalid_argument);
|
||||||
|
REQUIRE_THROWS_WITH(net2.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, "duck", raw.statesv), network_invalid);
|
||||||
|
// Classname
|
||||||
|
std::string invalid_classname = "Class Name not found in Network::features";
|
||||||
|
REQUIRE_THROWS_AS(net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, "duck", raw.statesv), std::invalid_argument);
|
||||||
|
REQUIRE_THROWS_WITH(net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, "duck", raw.statesv), invalid_classname);
|
||||||
|
// Invalid feature
|
||||||
|
auto features2 = raw.featuresv;
|
||||||
|
features2.pop_back();
|
||||||
|
features2.push_back("duck");
|
||||||
|
std::string invalid_feature = "Feature duck not found in Network::features";
|
||||||
|
REQUIRE_THROWS_AS(net.fit(raw.Xv, raw.yv, raw.weightsv, features2, raw.classNamev, raw.statesv), std::invalid_argument);
|
||||||
|
REQUIRE_THROWS_WITH(net.fit(raw.Xv, raw.yv, raw.weightsv, features2, raw.classNamev, raw.statesv), invalid_feature);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
TEST_CASE("Test and empty Node", "[Network]")
|
||||||
|
{
|
||||||
|
auto net = bayesnet::Network();
|
||||||
|
REQUIRE_THROWS_AS(net.addNode(""), std::invalid_argument);
|
||||||
|
REQUIRE_THROWS_WITH(net.addNode(""), "Node name cannot be empty");
|
||||||
|
}
|
||||||
|
TEST_CASE("Cicle in Network", "[Network]")
|
||||||
|
{
|
||||||
|
auto net = bayesnet::Network();
|
||||||
|
net.addNode("A");
|
||||||
|
net.addNode("B");
|
||||||
|
net.addNode("C");
|
||||||
|
net.addEdge("A", "B");
|
||||||
|
net.addEdge("B", "C");
|
||||||
|
REQUIRE_THROWS_AS(net.addEdge("C", "A"), std::invalid_argument);
|
||||||
|
REQUIRE_THROWS_WITH(net.addEdge("C", "A"), "Adding this edge forms a cycle in the graph.");
|
||||||
|
}
|
||||||
|
TEST_CASE("Test max threads constructor", "[Network]")
|
||||||
|
{
|
||||||
|
auto net = bayesnet::Network();
|
||||||
|
REQUIRE(net.getMaxThreads() == 0.95f);
|
||||||
|
auto net2 = bayesnet::Network(4);
|
||||||
|
REQUIRE(net2.getMaxThreads() == 4);
|
||||||
|
auto net3 = bayesnet::Network(1.75);
|
||||||
|
REQUIRE(net3.getMaxThreads() == 1.75);
|
||||||
|
}
|
||||||
|
TEST_CASE("Edges troubles", "[Network]")
|
||||||
|
{
|
||||||
|
auto net = bayesnet::Network();
|
||||||
|
net.addNode("A");
|
||||||
|
net.addNode("B");
|
||||||
|
REQUIRE_THROWS_AS(net.addEdge("A", "C"), std::invalid_argument);
|
||||||
|
REQUIRE_THROWS_WITH(net.addEdge("A", "C"), "Child node C does not exist");
|
||||||
|
REQUIRE_THROWS_AS(net.addEdge("C", "A"), std::invalid_argument);
|
||||||
|
REQUIRE_THROWS_WITH(net.addEdge("C", "A"), "Parent node C does not exist");
|
||||||
|
}
|
||||||
|
TEST_CASE("Dump CPT", "[Network]")
|
||||||
|
{
|
||||||
|
auto net = bayesnet::Network();
|
||||||
|
auto raw = RawDatasets("iris", true);
|
||||||
|
buildModel(net, raw.featuresv, raw.classNamev);
|
||||||
|
net.fit(raw.Xv, raw.yv, raw.weightsv, raw.featuresv, raw.classNamev, raw.statesv);
|
||||||
|
auto res = net.dump_cpt();
|
||||||
|
std::string expected = R"(* class: (3) : [3]
|
||||||
|
0.3333
|
||||||
|
0.3333
|
||||||
|
0.3333
|
||||||
|
[ CPUFloatType{3} ]
|
||||||
|
* petallength: (4) : [4, 3, 3]
|
||||||
|
(1,.,.) =
|
||||||
|
0.9388 0.1000 0.2000
|
||||||
|
0.6250 0.0526 0.1667
|
||||||
|
0.4000 0.0303 0.0196
|
||||||
|
|
||||||
|
(2,.,.) =
|
||||||
|
0.0204 0.7000 0.4000
|
||||||
|
0.1250 0.8421 0.1667
|
||||||
|
0.2000 0.7273 0.0196
|
||||||
|
|
||||||
|
(3,.,.) =
|
||||||
|
0.0204 0.1000 0.2000
|
||||||
|
0.1250 0.0526 0.5000
|
||||||
|
0.2000 0.1818 0.1373
|
||||||
|
|
||||||
|
(4,.,.) =
|
||||||
|
0.0204 0.1000 0.2000
|
||||||
|
0.1250 0.0526 0.1667
|
||||||
|
0.2000 0.0606 0.8235
|
||||||
|
[ CPUFloatType{4,3,3} ]
|
||||||
|
* petalwidth: (3) : [3, 6, 3]
|
||||||
|
(1,.,.) =
|
||||||
|
0.5000 0.0417 0.0714
|
||||||
|
0.3333 0.1111 0.0909
|
||||||
|
0.5000 0.1000 0.2000
|
||||||
|
0.7778 0.0909 0.0667
|
||||||
|
0.8667 0.1000 0.0667
|
||||||
|
0.9394 0.2500 0.1250
|
||||||
|
|
||||||
|
(2,.,.) =
|
||||||
|
0.2500 0.9167 0.2857
|
||||||
|
0.3333 0.7778 0.1818
|
||||||
|
0.2500 0.8000 0.2000
|
||||||
|
0.1111 0.8182 0.1333
|
||||||
|
0.0667 0.7000 0.0667
|
||||||
|
0.0303 0.5000 0.1250
|
||||||
|
|
||||||
|
(3,.,.) =
|
||||||
|
0.2500 0.0417 0.6429
|
||||||
|
0.3333 0.1111 0.7273
|
||||||
|
0.2500 0.1000 0.6000
|
||||||
|
0.1111 0.0909 0.8000
|
||||||
|
0.0667 0.2000 0.8667
|
||||||
|
0.0303 0.2500 0.7500
|
||||||
|
[ CPUFloatType{3,6,3} ]
|
||||||
|
* sepallength: (3) : [3, 3]
|
||||||
|
0.8679 0.1321 0.0377
|
||||||
|
0.0943 0.3019 0.0566
|
||||||
|
0.0377 0.5660 0.9057
|
||||||
|
[ CPUFloatType{3,3} ]
|
||||||
|
* sepalwidth: (6) : [6, 3, 3]
|
||||||
|
(1,.,.) =
|
||||||
|
0.0392 0.5000 0.2857
|
||||||
|
0.1000 0.4286 0.2500
|
||||||
|
0.1429 0.2571 0.1887
|
||||||
|
|
||||||
|
(2,.,.) =
|
||||||
|
0.0196 0.0833 0.1429
|
||||||
|
0.1000 0.1429 0.2500
|
||||||
|
0.1429 0.1429 0.1509
|
||||||
|
|
||||||
|
(3,.,.) =
|
||||||
|
0.0392 0.0833 0.1429
|
||||||
|
0.1000 0.1429 0.1250
|
||||||
|
0.1429 0.1714 0.0566
|
||||||
|
|
||||||
|
(4,.,.) =
|
||||||
|
0.1373 0.1667 0.1429
|
||||||
|
0.1000 0.1905 0.1250
|
||||||
|
0.1429 0.1429 0.2453
|
||||||
|
|
||||||
|
(5,.,.) =
|
||||||
|
0.2549 0.0833 0.1429
|
||||||
|
0.1000 0.0476 0.1250
|
||||||
|
0.1429 0.2286 0.2453
|
||||||
|
|
||||||
|
(6,.,.) =
|
||||||
|
0.5098 0.0833 0.1429
|
||||||
|
0.5000 0.0476 0.1250
|
||||||
|
0.2857 0.0571 0.1132
|
||||||
|
[ CPUFloatType{6,3,3} ]
|
||||||
|
)";
|
||||||
|
REQUIRE(res == expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
84
tests/TestBayesNode.cc
Normal file
84
tests/TestBayesNode.cc
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
#include <catch2/catch_test_macros.hpp>
|
||||||
|
#include <catch2/catch_approx.hpp>
|
||||||
|
#include <catch2/generators/catch_generators.hpp>
|
||||||
|
#include <string>
|
||||||
|
#include "TestUtils.h"
|
||||||
|
#include "bayesnet/network/Network.h"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
TEST_CASE("Test Node children and parents", "[Node]")
|
||||||
|
{
|
||||||
|
auto node = bayesnet::Node("Node");
|
||||||
|
REQUIRE(node.getName() == "Node");
|
||||||
|
auto parent_1 = bayesnet::Node("P1");
|
||||||
|
auto parent_2 = bayesnet::Node("P2");
|
||||||
|
auto child_1 = bayesnet::Node("H1");
|
||||||
|
auto child_2 = bayesnet::Node("H2");
|
||||||
|
auto child_3 = bayesnet::Node("H3");
|
||||||
|
node.addParent(&parent_1);
|
||||||
|
node.addParent(&parent_2);
|
||||||
|
node.addChild(&child_1);
|
||||||
|
node.addChild(&child_2);
|
||||||
|
node.addChild(&child_3);
|
||||||
|
auto parents = node.getParents();
|
||||||
|
auto children = node.getChildren();
|
||||||
|
REQUIRE(parents.size() == 2);
|
||||||
|
REQUIRE(children.size() == 3);
|
||||||
|
REQUIRE(parents[0]->getName() == "P1");
|
||||||
|
REQUIRE(parents[1]->getName() == "P2");
|
||||||
|
REQUIRE(children[0]->getName() == "H1");
|
||||||
|
REQUIRE(children[1]->getName() == "H2");
|
||||||
|
REQUIRE(children[2]->getName() == "H3");
|
||||||
|
node.removeParent(&parent_1);
|
||||||
|
node.removeChild(&child_1);
|
||||||
|
parents = node.getParents();
|
||||||
|
children = node.getChildren();
|
||||||
|
REQUIRE(parents.size() == 1);
|
||||||
|
REQUIRE(children.size() == 2);
|
||||||
|
node.clear();
|
||||||
|
parents = node.getParents();
|
||||||
|
children = node.getChildren();
|
||||||
|
REQUIRE(parents.size() == 0);
|
||||||
|
REQUIRE(children.size() == 0);
|
||||||
|
}
|
||||||
|
TEST_CASE("TEST MinFill method", "[Node]")
|
||||||
|
{
|
||||||
|
// Generate a test to test the minFill method of the Node class
|
||||||
|
// Create a graph with 5 nodes
|
||||||
|
// The graph is a chain with some additional edges
|
||||||
|
// 0 -> 1,2,3
|
||||||
|
// 1 -> 2,4
|
||||||
|
// 2 -> 3
|
||||||
|
// 3 -> 4
|
||||||
|
auto node_0 = bayesnet::Node("0");
|
||||||
|
auto node_1 = bayesnet::Node("1");
|
||||||
|
auto node_2 = bayesnet::Node("2");
|
||||||
|
auto node_3 = bayesnet::Node("3");
|
||||||
|
auto node_4 = bayesnet::Node("4");
|
||||||
|
// node 0
|
||||||
|
node_0.addChild(&node_1);
|
||||||
|
node_0.addChild(&node_2);
|
||||||
|
node_0.addChild(&node_3);
|
||||||
|
// node 1
|
||||||
|
node_1.addChild(&node_2);
|
||||||
|
node_1.addChild(&node_4);
|
||||||
|
node_1.addParent(&node_0);
|
||||||
|
// node 2
|
||||||
|
node_2.addChild(&node_3);
|
||||||
|
node_2.addChild(&node_4);
|
||||||
|
node_2.addParent(&node_0);
|
||||||
|
node_2.addParent(&node_1);
|
||||||
|
// node 3
|
||||||
|
node_3.addChild(&node_4);
|
||||||
|
node_3.addParent(&node_0);
|
||||||
|
node_3.addParent(&node_2);
|
||||||
|
// node 4
|
||||||
|
node_4.addParent(&node_1);
|
||||||
|
node_4.addParent(&node_3);
|
||||||
|
REQUIRE(node_0.minFill() == 3);
|
||||||
|
REQUIRE(node_1.minFill() == 3);
|
||||||
|
REQUIRE(node_2.minFill() == 6);
|
||||||
|
REQUIRE(node_3.minFill() == 3);
|
||||||
|
REQUIRE(node_4.minFill() == 1);
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user