From 9014649a0dba3ae51942f74cad187480d1e78a7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Mon, 8 Apr 2024 00:55:30 +0200 Subject: [PATCH] Refactor hyperparameters classifier management --- bayesnet/classifiers/Classifier.cc | 14 +++++++++----- bayesnet/classifiers/KDB.cc | 6 +++++- bayesnet/classifiers/KDB.h | 2 +- bayesnet/ensembles/AODE.cc | 4 +--- bayesnet/ensembles/BoostAODE.cc | 4 +--- bayesnet/ensembles/BoostAODE.h | 2 +- tests/CMakeLists.txt | 3 ++- tests/TestBayesClassifier.cc | 23 +++++++++++++++++++++++ 8 files changed, 43 insertions(+), 15 deletions(-) create mode 100644 tests/TestBayesClassifier.cc diff --git a/bayesnet/classifiers/Classifier.cc b/bayesnet/classifiers/Classifier.cc index eed8d91..19f337a 100644 --- a/bayesnet/classifiers/Classifier.cc +++ b/bayesnet/classifiers/Classifier.cc @@ -1,3 +1,4 @@ +#include #include "bayesnet/utils/bayesnetUtils.h" #include "Classifier.h" @@ -27,10 +28,11 @@ namespace bayesnet { dataset = torch::cat({ dataset, yresized }, 0); } catch (const std::exception& e) { - std::cerr << e.what() << '\n'; - std::cout << "X dimensions: " << dataset.sizes() << "\n"; - std::cout << "y dimensions: " << ytmp.sizes() << "\n"; - exit(1); + std::stringstream oss; + oss << "* Error in X and y dimensions *\n"; + oss << "X dimensions: " << dataset.sizes() << "\n"; + oss << "y dimensions: " << ytmp.sizes(); + throw std::runtime_error(oss.str()); } } void Classifier::trainModel(const torch::Tensor& weights) @@ -179,6 +181,8 @@ namespace bayesnet { } void Classifier::setHyperparameters(const nlohmann::json& hyperparameters) { - //For classifiers that don't have hyperparameters + if (!hyperparameters.empty()) { + throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump()); + } } } \ No newline at end of file diff --git a/bayesnet/classifiers/KDB.cc b/bayesnet/classifiers/KDB.cc index 7781ca0..6c4bb99 100644 --- a/bayesnet/classifiers/KDB.cc +++ b/bayesnet/classifiers/KDB.cc @@ -6,14 +6,18 @@ namespace bayesnet { validHyperparameters = { "k", "theta" }; } - void KDB::setHyperparameters(const nlohmann::json& hyperparameters) + void KDB::setHyperparameters(const nlohmann::json& hyperparameters_) { + auto hyperparameters = hyperparameters_; if (hyperparameters.contains("k")) { k = hyperparameters["k"]; + hyperparameters.erase("k"); } if (hyperparameters.contains("theta")) { theta = hyperparameters["theta"]; + hyperparameters.erase("theta"); } + Classifier::setHyperparameters(hyperparameters); } void KDB::buildModel(const torch::Tensor& weights) { diff --git a/bayesnet/classifiers/KDB.h b/bayesnet/classifiers/KDB.h index 9478475..17c2a1f 100644 --- a/bayesnet/classifiers/KDB.h +++ b/bayesnet/classifiers/KDB.h @@ -14,7 +14,7 @@ namespace bayesnet { public: explicit KDB(int k, float theta = 0.03); virtual ~KDB() = default; - void setHyperparameters(const nlohmann::json& hyperparameters) override; + void setHyperparameters(const nlohmann::json& hyperparameters_) override; std::vector graph(const std::string& name = "KDB") const override; }; } diff --git a/bayesnet/ensembles/AODE.cc b/bayesnet/ensembles/AODE.cc index f984f9d..22b17b8 100644 --- a/bayesnet/ensembles/AODE.cc +++ b/bayesnet/ensembles/AODE.cc @@ -13,9 +13,7 @@ namespace bayesnet { predict_voting = hyperparameters["predict_voting"]; hyperparameters.erase("predict_voting"); } - if (!hyperparameters.empty()) { - throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump()); - } + Classifier::setHyperparameters(hyperparameters); } void AODE::buildModel(const torch::Tensor& weights) { diff --git a/bayesnet/ensembles/BoostAODE.cc b/bayesnet/ensembles/BoostAODE.cc index 8426638..9e4a856 100644 --- a/bayesnet/ensembles/BoostAODE.cc +++ b/bayesnet/ensembles/BoostAODE.cc @@ -94,9 +94,7 @@ namespace bayesnet { } hyperparameters.erase("select_features"); } - if (!hyperparameters.empty()) { - throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump()); - } + Classifier::setHyperparameters(hyperparameters); } std::tuple update_weights(torch::Tensor& ytrain, torch::Tensor& ypred, torch::Tensor& weights) { diff --git a/bayesnet/ensembles/BoostAODE.h b/bayesnet/ensembles/BoostAODE.h index dc074fb..f4091df 100644 --- a/bayesnet/ensembles/BoostAODE.h +++ b/bayesnet/ensembles/BoostAODE.h @@ -20,7 +20,7 @@ namespace bayesnet { BoostAODE(bool predict_voting = false); virtual ~BoostAODE() = default; std::vector graph(const std::string& title = "BoostAODE") const override; - void setHyperparameters(const nlohmann::json& hyperparameters) override; + void setHyperparameters(const nlohmann::json& hyperparameters_) override; protected: void buildModel(const torch::Tensor& weights) override; void trainModel(const torch::Tensor& weights) override; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 0f0abf5..02fd775 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -8,12 +8,13 @@ if(ENABLE_TESTING) ${CMAKE_BINARY_DIR}/configured_files/include ) file(GLOB_RECURSE BayesNet_SOURCES "${BayesNet_SOURCE_DIR}/bayesnet/*.cc") - add_executable(TestBayesNet TestBayesNetwork.cc TestBayesNode.cc TestBayesModels.cc TestBayesMetrics.cc TestFeatureSelection.cc TestUtils.cc ${BayesNet_SOURCES}) + add_executable(TestBayesNet TestBayesNetwork.cc TestBayesNode.cc TestBayesClassifier.cc TestBayesModels.cc TestBayesMetrics.cc TestFeatureSelection.cc TestUtils.cc ${BayesNet_SOURCES}) target_link_libraries(TestBayesNet PUBLIC "${TORCH_LIBRARIES}" ArffFiles mdlp Catch2::Catch2WithMain ) add_test(NAME BayesNetworkTest COMMAND TestBayesNet) add_test(NAME Network COMMAND TestBayesNet "[Network]") add_test(NAME Node COMMAND TestBayesNet "[Node]") add_test(NAME Metrics COMMAND TestBayesNet "[Metrics]") add_test(NAME FeatureSelection COMMAND TestBayesNet "[FeatureSelection]") + add_test(NAME Classifier COMMAND TestBayesNet "[Classifier]") add_test(NAME Models COMMAND TestBayesNet "[Models]") endif(ENABLE_TESTING) diff --git a/tests/TestBayesClassifier.cc b/tests/TestBayesClassifier.cc new file mode 100644 index 0000000..d07adbd --- /dev/null +++ b/tests/TestBayesClassifier.cc @@ -0,0 +1,23 @@ +#include +#include +#include +#include "TestUtils.h" +#include "bayesnet/classifiers/TAN.h" + + +TEST_CASE("Test Cannot build dataset with wrong data vector", "[Classifier]") +{ + auto model = bayesnet::TAN(); + auto raw = RawDatasets("iris", true); + raw.yv.pop_back(); + REQUIRE_THROWS_AS(model.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv), std::runtime_error); + REQUIRE_THROWS_WITH(model.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv), "* Error in X and y dimensions *\nX dimensions: [4, 150]\ny dimensions: [149]"); +} +TEST_CASE("Test Cannot build dataset with wrong data tensor", "[Classifier]") +{ + auto model = bayesnet::TAN(); + auto raw = RawDatasets("iris", true); + auto yshort = torch::zeros({ 149 }, torch::kInt32); + REQUIRE_THROWS_AS(model.fit(raw.Xt, yshort, raw.featurest, raw.classNamet, raw.statest), std::runtime_error); + REQUIRE_THROWS_WITH(model.fit(raw.Xt, yshort, raw.featurest, raw.classNamet, raw.statest), "* Error in X and y dimensions *\nX dimensions: [4, 150]\ny dimensions: [149]"); +} \ No newline at end of file