Add tests to reach 90% coverage

This commit is contained in:
2024-04-08 00:13:59 +02:00
parent 46cb8d30eb
commit 0d6a081d01
13 changed files with 424 additions and 56 deletions

View File

@@ -1,27 +1,35 @@
#include <thread>
#include <mutex>
#include <sstream>
#include "Network.h"
#include "bayesnet/utils/bayesnetUtils.h"
namespace bayesnet {
Network::Network() : features(std::vector<std::string>()), className(""), classNumStates(0), fitted(false), laplaceSmoothing(0) {}
Network::Network(float maxT) : features(std::vector<std::string>()), className(""), classNumStates(0), maxThreads(maxT), fitted(false), laplaceSmoothing(0) {}
Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), maxThreads(other.
getmaxThreads()), fitted(other.fitted)
Network::Network() : fitted{ false }, maxThreads{ 0.95 }, classNumStates{ 0 }, laplaceSmoothing{ 0 }
{
}
Network::Network(float maxT) : fitted{ false }, maxThreads{ maxT }, classNumStates{ 0 }, laplaceSmoothing{ 0 }
{
}
Network::Network(const Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()),
maxThreads(other.getMaxThreads()), fitted(other.fitted), samples(other.samples)
{
if (samples.defined())
samples = samples.clone();
for (const auto& node : other.nodes) {
nodes[node.first] = std::make_unique<Node>(*node.second);
}
}
void Network::initialize()
{
features = std::vector<std::string>();
features.clear();
className = "";
classNumStates = 0;
fitted = false;
nodes.clear();
samples = torch::Tensor();
}
float Network::getmaxThreads()
float Network::getMaxThreads() const
{
return maxThreads;
}
@@ -114,11 +122,14 @@ namespace bayesnet {
if (n_features != featureNames.size()) {
throw std::invalid_argument("X and features must have the same number of features in Network::fit (" + std::to_string(n_features) + " != " + std::to_string(featureNames.size()) + ")");
}
if (features.size() == 0) {
throw std::invalid_argument("The network has not been initialized. You must call addNode() before calling fit()");
}
if (n_features != features.size() - 1) {
throw std::invalid_argument("X and local features must have the same number of features in Network::fit (" + std::to_string(n_features) + " != " + std::to_string(features.size() - 1) + ")");
}
if (find(features.begin(), features.end(), className) == features.end()) {
throw std::invalid_argument("className not found in Network::features");
throw std::invalid_argument("Class Name not found in Network::features");
}
for (auto& feature : featureNames) {
if (find(features.begin(), features.end(), feature) == features.end()) {
@@ -404,11 +415,13 @@ namespace bayesnet {
}
return result;
}
void Network::dump_cpt() const
std::string Network::dump_cpt() const
{
std::stringstream oss;
for (auto& node : nodes) {
std::cout << "* " << node.first << ": (" << node.second->getNumStates() << ") : " << node.second->getCPT().sizes() << std::endl;
std::cout << node.second->getCPT() << std::endl;
oss << "* " << node.first << ": (" << node.second->getNumStates() << ") : " << node.second->getCPT().sizes() << std::endl;
oss << node.second->getCPT() << std::endl;
}
return oss.str();
}
}