Compare commits

...

2 Commits

12 changed files with 116 additions and 24 deletions

View File

@@ -30,7 +30,7 @@ namespace bayesnet {
virtual std::string getVersion() = 0; virtual std::string getVersion() = 0;
std::vector<std::string> virtual topological_order() = 0; std::vector<std::string> virtual topological_order() = 0;
std::vector<std::string> virtual getNotes() const = 0; std::vector<std::string> virtual getNotes() const = 0;
void virtual dump_cpt()const = 0; std::string virtual dump_cpt()const = 0;
virtual void setHyperparameters(const nlohmann::json& hyperparameters) = 0; virtual void setHyperparameters(const nlohmann::json& hyperparameters) = 0;
std::vector<std::string>& getValidHyperparameters() { return validHyperparameters; } std::vector<std::string>& getValidHyperparameters() { return validHyperparameters; }
protected: protected:

View File

@@ -1,3 +1,4 @@
#include <sstream>
#include "bayesnet/utils/bayesnetUtils.h" #include "bayesnet/utils/bayesnetUtils.h"
#include "Classifier.h" #include "Classifier.h"
@@ -27,10 +28,11 @@ namespace bayesnet {
dataset = torch::cat({ dataset, yresized }, 0); dataset = torch::cat({ dataset, yresized }, 0);
} }
catch (const std::exception& e) { catch (const std::exception& e) {
std::cerr << e.what() << '\n'; std::stringstream oss;
std::cout << "X dimensions: " << dataset.sizes() << "\n"; oss << "* Error in X and y dimensions *\n";
std::cout << "y dimensions: " << ytmp.sizes() << "\n"; oss << "X dimensions: " << dataset.sizes() << "\n";
exit(1); oss << "y dimensions: " << ytmp.sizes();
throw std::runtime_error(oss.str());
} }
} }
void Classifier::trainModel(const torch::Tensor& weights) void Classifier::trainModel(const torch::Tensor& weights)
@@ -73,11 +75,11 @@ namespace bayesnet {
if (torch::is_floating_point(dataset)) { if (torch::is_floating_point(dataset)) {
throw std::invalid_argument("dataset (X, y) must be of type Integer"); throw std::invalid_argument("dataset (X, y) must be of type Integer");
} }
if (n != features.size()) { if (dataset.size(0) - 1 != features.size()) {
throw std::invalid_argument("Classifier: X " + std::to_string(n) + " and features " + std::to_string(features.size()) + " must have the same number of features"); throw std::invalid_argument("Classifier: X " + std::to_string(dataset.size(0) - 1) + " and features " + std::to_string(features.size()) + " must have the same number of features");
} }
if (states.find(className) == states.end()) { if (states.find(className) == states.end()) {
throw std::invalid_argument("className not found in states"); throw std::invalid_argument("class name not found in states");
} }
for (auto feature : features) { for (auto feature : features) {
if (states.find(feature) == states.end()) { if (states.find(feature) == states.end()) {
@@ -173,12 +175,14 @@ namespace bayesnet {
{ {
return model.topological_sort(); return model.topological_sort();
} }
void Classifier::dump_cpt() const std::string Classifier::dump_cpt() const
{ {
model.dump_cpt(); return model.dump_cpt();
} }
void Classifier::setHyperparameters(const nlohmann::json& hyperparameters) void Classifier::setHyperparameters(const nlohmann::json& hyperparameters)
{ {
//For classifiers that don't have hyperparameters if (!hyperparameters.empty()) {
throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump());
}
} }
} }

View File

@@ -30,7 +30,7 @@ namespace bayesnet {
std::vector<std::string> show() const override; std::vector<std::string> show() const override;
std::vector<std::string> topological_order() override; std::vector<std::string> topological_order() override;
std::vector<std::string> getNotes() const override { return notes; } std::vector<std::string> getNotes() const override { return notes; }
void dump_cpt() const override; std::string dump_cpt() const override;
void setHyperparameters(const nlohmann::json& hyperparameters) override; //For classifiers that don't have hyperparameters void setHyperparameters(const nlohmann::json& hyperparameters) override; //For classifiers that don't have hyperparameters
protected: protected:
bool fitted; bool fitted;

View File

@@ -6,14 +6,18 @@ namespace bayesnet {
validHyperparameters = { "k", "theta" }; validHyperparameters = { "k", "theta" };
} }
void KDB::setHyperparameters(const nlohmann::json& hyperparameters) void KDB::setHyperparameters(const nlohmann::json& hyperparameters_)
{ {
auto hyperparameters = hyperparameters_;
if (hyperparameters.contains("k")) { if (hyperparameters.contains("k")) {
k = hyperparameters["k"]; k = hyperparameters["k"];
hyperparameters.erase("k");
} }
if (hyperparameters.contains("theta")) { if (hyperparameters.contains("theta")) {
theta = hyperparameters["theta"]; theta = hyperparameters["theta"];
hyperparameters.erase("theta");
} }
Classifier::setHyperparameters(hyperparameters);
} }
void KDB::buildModel(const torch::Tensor& weights) void KDB::buildModel(const torch::Tensor& weights)
{ {

View File

@@ -14,7 +14,7 @@ namespace bayesnet {
public: public:
explicit KDB(int k, float theta = 0.03); explicit KDB(int k, float theta = 0.03);
virtual ~KDB() = default; virtual ~KDB() = default;
void setHyperparameters(const nlohmann::json& hyperparameters) override; void setHyperparameters(const nlohmann::json& hyperparameters_) override;
std::vector<std::string> graph(const std::string& name = "KDB") const override; std::vector<std::string> graph(const std::string& name = "KDB") const override;
}; };
} }

View File

@@ -13,9 +13,7 @@ namespace bayesnet {
predict_voting = hyperparameters["predict_voting"]; predict_voting = hyperparameters["predict_voting"];
hyperparameters.erase("predict_voting"); hyperparameters.erase("predict_voting");
} }
if (!hyperparameters.empty()) { Classifier::setHyperparameters(hyperparameters);
throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump());
}
} }
void AODE::buildModel(const torch::Tensor& weights) void AODE::buildModel(const torch::Tensor& weights)
{ {

View File

@@ -94,9 +94,7 @@ namespace bayesnet {
} }
hyperparameters.erase("select_features"); hyperparameters.erase("select_features");
} }
if (!hyperparameters.empty()) { Classifier::setHyperparameters(hyperparameters);
throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump());
}
} }
std::tuple<torch::Tensor&, double, bool> update_weights(torch::Tensor& ytrain, torch::Tensor& ypred, torch::Tensor& weights) std::tuple<torch::Tensor&, double, bool> update_weights(torch::Tensor& ytrain, torch::Tensor& ypred, torch::Tensor& weights)
{ {

View File

@@ -20,7 +20,7 @@ namespace bayesnet {
BoostAODE(bool predict_voting = false); BoostAODE(bool predict_voting = false);
virtual ~BoostAODE() = default; virtual ~BoostAODE() = default;
std::vector<std::string> graph(const std::string& title = "BoostAODE") const override; std::vector<std::string> graph(const std::string& title = "BoostAODE") const override;
void setHyperparameters(const nlohmann::json& hyperparameters) override; void setHyperparameters(const nlohmann::json& hyperparameters_) override;
protected: protected:
void buildModel(const torch::Tensor& weights) override; void buildModel(const torch::Tensor& weights) override;
void trainModel(const torch::Tensor& weights) override; void trainModel(const torch::Tensor& weights) override;

View File

@@ -25,8 +25,9 @@ namespace bayesnet {
{ {
return std::vector<std::string>(); return std::vector<std::string>();
} }
void dump_cpt() const override std::string dump_cpt() const override
{ {
return "";
} }
protected: protected:
torch::Tensor predict_average_voting(torch::Tensor& X); torch::Tensor predict_average_voting(torch::Tensor& X);

View File

@@ -8,12 +8,13 @@ if(ENABLE_TESTING)
${CMAKE_BINARY_DIR}/configured_files/include ${CMAKE_BINARY_DIR}/configured_files/include
) )
file(GLOB_RECURSE BayesNet_SOURCES "${BayesNet_SOURCE_DIR}/bayesnet/*.cc") file(GLOB_RECURSE BayesNet_SOURCES "${BayesNet_SOURCE_DIR}/bayesnet/*.cc")
add_executable(TestBayesNet TestBayesNetwork.cc TestBayesNode.cc TestBayesModels.cc TestBayesMetrics.cc TestFeatureSelection.cc TestUtils.cc ${BayesNet_SOURCES}) add_executable(TestBayesNet TestBayesNetwork.cc TestBayesNode.cc TestBayesClassifier.cc TestBayesModels.cc TestBayesMetrics.cc TestFeatureSelection.cc TestUtils.cc ${BayesNet_SOURCES})
target_link_libraries(TestBayesNet PUBLIC "${TORCH_LIBRARIES}" ArffFiles mdlp Catch2::Catch2WithMain ) target_link_libraries(TestBayesNet PUBLIC "${TORCH_LIBRARIES}" ArffFiles mdlp Catch2::Catch2WithMain )
add_test(NAME BayesNetworkTest COMMAND TestBayesNet) add_test(NAME BayesNetworkTest COMMAND TestBayesNet)
add_test(NAME Network COMMAND TestBayesNet "[Network]") add_test(NAME Network COMMAND TestBayesNet "[Network]")
add_test(NAME Node COMMAND TestBayesNet "[Node]") add_test(NAME Node COMMAND TestBayesNet "[Node]")
add_test(NAME Metrics COMMAND TestBayesNet "[Metrics]") add_test(NAME Metrics COMMAND TestBayesNet "[Metrics]")
add_test(NAME FeatureSelection COMMAND TestBayesNet "[FeatureSelection]") add_test(NAME FeatureSelection COMMAND TestBayesNet "[FeatureSelection]")
add_test(NAME Classifier COMMAND TestBayesNet "[Classifier]")
add_test(NAME Models COMMAND TestBayesNet "[Models]") add_test(NAME Models COMMAND TestBayesNet "[Models]")
endif(ENABLE_TESTING) endif(ENABLE_TESTING)

View File

@@ -0,0 +1,86 @@
#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers.hpp>
#include <string>
#include "TestUtils.h"
#include "bayesnet/classifiers/TAN.h"
// fit() on std::vector inputs must reject a label vector that is one element
// shorter than the sample count, throwing std::runtime_error with a message
// that reports both shapes.
TEST_CASE("Test Cannot build dataset with wrong data vector", "[Classifier]")
{
    auto clf = bayesnet::TAN();
    // NOTE(review): the second argument presumably selects the discretized
    // dataset -- confirm against TestUtils.h.
    auto iris = RawDatasets("iris", true);
    // Drop the last label so y no longer lines up with X.
    iris.yv.pop_back();
    const std::string expected = "* Error in X and y dimensions *\nX dimensions: [4, 150]\ny dimensions: [149]";
    REQUIRE_THROWS_AS(
        clf.fit(iris.Xv, iris.yv, iris.featuresv, iris.classNamev, iris.statesv),
        std::runtime_error);
    REQUIRE_THROWS_WITH(
        clf.fit(iris.Xv, iris.yv, iris.featuresv, iris.classNamev, iris.statesv),
        expected);
}
// Tensor flavour of the dimension check: a 149-element label tensor paired
// with the iris feature tensor must make fit() throw std::runtime_error with
// a message that reports both shapes.
TEST_CASE("Test Cannot build dataset with wrong data tensor", "[Classifier]")
{
    auto clf = bayesnet::TAN();
    auto iris = RawDatasets("iris", true);
    // One label fewer than the 150 samples quoted in the expected message.
    auto truncated_y = torch::zeros({ 149 }, torch::kInt32);
    const std::string expected = "* Error in X and y dimensions *\nX dimensions: [4, 150]\ny dimensions: [149]";
    REQUIRE_THROWS_AS(
        clf.fit(iris.Xt, truncated_y, iris.featurest, iris.classNamet, iris.statest),
        std::runtime_error);
    REQUIRE_THROWS_WITH(
        clf.fit(iris.Xt, truncated_y, iris.featurest, iris.classNamet, iris.statest),
        expected);
}
// fit() must refuse this dataset with std::invalid_argument and the
// "must be of type Integer" message.
TEST_CASE("Invalid data type", "[Classifier]")
{
    auto clf = bayesnet::TAN();
    // NOTE(review): passing false presumably loads the non-discretized
    // (floating point) data -- confirm against TestUtils.h.
    auto iris = RawDatasets("iris", false);
    const std::string expected = "dataset (X, y) must be of type Integer";
    REQUIRE_THROWS_AS(
        clf.fit(iris.Xt, iris.yt, iris.featurest, iris.classNamet, iris.statest),
        std::invalid_argument);
    REQUIRE_THROWS_WITH(
        clf.fit(iris.Xt, iris.yt, iris.featurest, iris.classNamet, iris.statest),
        expected);
}
// Appending an extra row to X gives the tensor one more feature than the
// feature-name list; fit() must report the mismatch via std::invalid_argument.
TEST_CASE("Invalid number of features", "[Classifier]")
{
    auto clf = bayesnet::TAN();
    auto iris = RawDatasets("iris", true);
    // Concatenate a 1 x 150 row of zeros along dimension 0.
    auto extra_row = torch::zeros({ 1, 150 }, torch::kInt32);
    auto widened_X = torch::cat({ iris.Xt, extra_row }, 0);
    const std::string expected = "Classifier: X 5 and features 4 must have the same number of features";
    REQUIRE_THROWS_AS(
        clf.fit(widened_X, iris.yt, iris.featurest, iris.classNamet, iris.statest),
        std::invalid_argument);
    REQUIRE_THROWS_WITH(
        clf.fit(widened_X, iris.yt, iris.featurest, iris.classNamet, iris.statest),
        expected);
}
// Passing "duck" as the class name must make fit() throw
// std::invalid_argument with the "class name not found in states" message.
TEST_CASE("Invalid class name", "[Classifier]")
{
    auto clf = bayesnet::TAN();
    auto iris = RawDatasets("iris", true);
    const std::string expected = "class name not found in states";
    REQUIRE_THROWS_AS(
        clf.fit(iris.Xt, iris.yt, iris.featurest, "duck", iris.statest),
        std::invalid_argument);
    REQUIRE_THROWS_WITH(
        clf.fit(iris.Xt, iris.yt, iris.featurest, "duck", iris.statest),
        expected);
}
// Removing "petallength" from a copy of the states map must make fit() throw
// std::invalid_argument naming the missing feature.
TEST_CASE("Invalid feature name", "[Classifier]")
{
    auto clf = bayesnet::TAN();
    auto iris = RawDatasets("iris", true);
    // Work on a copy so the fixture's own states map stays intact.
    auto incomplete_states = iris.statest;
    incomplete_states.erase("petallength");
    const std::string expected = "feature [petallength] not found in states";
    REQUIRE_THROWS_AS(
        clf.fit(iris.Xt, iris.yt, iris.featurest, iris.classNamet, incomplete_states),
        std::invalid_argument);
    REQUIRE_THROWS_WITH(
        clf.fit(iris.Xt, iris.yt, iris.featurest, iris.classNamet, incomplete_states),
        expected);
}
// After fitting TAN on iris, topological_order() must list the four features
// in the exact sequence pinned below.
TEST_CASE("Topological order", "[Classifier]")
{
    auto clf = bayesnet::TAN();
    auto iris = RawDatasets("iris", true);
    clf.fit(iris.Xt, iris.yt, iris.featurest, iris.classNamet, iris.statest);
    const auto sorted_nodes = clf.topological_order();
    // Check the length first so the indexed checks below stay in bounds.
    REQUIRE(sorted_nodes.size() == 4);
    REQUIRE(sorted_nodes[0] == "petallength");
    REQUIRE(sorted_nodes[1] == "sepallength");
    REQUIRE(sorted_nodes[2] == "sepalwidth");
    REQUIRE(sorted_nodes[3] == "petalwidth");
}
// predict(), predict_proba() and score() on a model that was never fitted must
// each throw std::logic_error with the same message, for both the tensor and
// the std::vector overloads.
TEST_CASE("Not fitted model", "[Classifier]")
{
    auto clf = bayesnet::TAN();
    auto iris = RawDatasets("iris", true);
    const std::string not_fitted = "Classifier has not been fitted";

    // Tensor inputs.
    REQUIRE_THROWS_AS(clf.predict(iris.Xt), std::logic_error);
    REQUIRE_THROWS_WITH(clf.predict(iris.Xt), not_fitted);
    REQUIRE_THROWS_AS(clf.predict_proba(iris.Xt), std::logic_error);
    REQUIRE_THROWS_WITH(clf.predict_proba(iris.Xt), not_fitted);
    REQUIRE_THROWS_AS(clf.score(iris.Xt, iris.yt), std::logic_error);
    REQUIRE_THROWS_WITH(clf.score(iris.Xt, iris.yt), not_fitted);

    // std::vector inputs.
    REQUIRE_THROWS_AS(clf.predict(iris.Xv), std::logic_error);
    REQUIRE_THROWS_WITH(clf.predict(iris.Xv), not_fitted);
    REQUIRE_THROWS_AS(clf.predict_proba(iris.Xv), std::logic_error);
    REQUIRE_THROWS_WITH(clf.predict_proba(iris.Xv), not_fitted);
    REQUIRE_THROWS_AS(clf.score(iris.Xv, iris.yv), std::logic_error);
    REQUIRE_THROWS_WITH(clf.score(iris.Xv, iris.yv), not_fitted);
}

View File

@@ -246,7 +246,7 @@ TEST_CASE("BoostAODE voting-proba", "[Models]")
REQUIRE(score_voting == Catch::Approx(0.98).epsilon(raw.epsilon)); REQUIRE(score_voting == Catch::Approx(0.98).epsilon(raw.epsilon));
REQUIRE(pred_voting[83][2] == Catch::Approx(0.552091).epsilon(raw.epsilon)); REQUIRE(pred_voting[83][2] == Catch::Approx(0.552091).epsilon(raw.epsilon));
REQUIRE(pred_proba[83][2] == Catch::Approx(0.546017).epsilon(raw.epsilon)); REQUIRE(pred_proba[83][2] == Catch::Approx(0.546017).epsilon(raw.epsilon));
clf.dump_cpt(); REQUIRE(clf.dump_cpt() == "");
REQUIRE(clf.topological_order() == std::vector<std::string>()); REQUIRE(clf.topological_order() == std::vector<std::string>());
} }
TEST_CASE("AODE voting-proba", "[Models]") TEST_CASE("AODE voting-proba", "[Models]")