Refactor hyperparameters classifier management

This commit is contained in:
Ricardo Montañana Gómez 2024-04-08 00:55:30 +02:00
parent 0d6a081d01
commit 9014649a0d
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
8 changed files with 43 additions and 15 deletions

View File

@ -1,3 +1,4 @@
#include <sstream>
#include "bayesnet/utils/bayesnetUtils.h"
#include "Classifier.h"
@ -27,10 +28,11 @@ namespace bayesnet {
dataset = torch::cat({ dataset, yresized }, 0);
}
catch (const std::exception& e) {
std::cerr << e.what() << '\n';
std::cout << "X dimensions: " << dataset.sizes() << "\n";
std::cout << "y dimensions: " << ytmp.sizes() << "\n";
exit(1);
std::stringstream oss;
oss << "* Error in X and y dimensions *\n";
oss << "X dimensions: " << dataset.sizes() << "\n";
oss << "y dimensions: " << ytmp.sizes();
throw std::runtime_error(oss.str());
}
}
void Classifier::trainModel(const torch::Tensor& weights)
@ -179,6 +181,8 @@ namespace bayesnet {
}
void Classifier::setHyperparameters(const nlohmann::json& hyperparameters)
{
//For classifiers that don't have hyperparameters
if (!hyperparameters.empty()) {
throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump());
}
}
}

View File

@ -6,14 +6,18 @@ namespace bayesnet {
validHyperparameters = { "k", "theta" };
}
void KDB::setHyperparameters(const nlohmann::json& hyperparameters)
void KDB::setHyperparameters(const nlohmann::json& hyperparameters_)
{
auto hyperparameters = hyperparameters_;
if (hyperparameters.contains("k")) {
k = hyperparameters["k"];
hyperparameters.erase("k");
}
if (hyperparameters.contains("theta")) {
theta = hyperparameters["theta"];
hyperparameters.erase("theta");
}
Classifier::setHyperparameters(hyperparameters);
}
void KDB::buildModel(const torch::Tensor& weights)
{

View File

@ -14,7 +14,7 @@ namespace bayesnet {
public:
explicit KDB(int k, float theta = 0.03);
virtual ~KDB() = default;
void setHyperparameters(const nlohmann::json& hyperparameters) override;
void setHyperparameters(const nlohmann::json& hyperparameters_) override;
std::vector<std::string> graph(const std::string& name = "KDB") const override;
};
}

View File

@ -13,9 +13,7 @@ namespace bayesnet {
predict_voting = hyperparameters["predict_voting"];
hyperparameters.erase("predict_voting");
}
if (!hyperparameters.empty()) {
throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump());
}
Classifier::setHyperparameters(hyperparameters);
}
void AODE::buildModel(const torch::Tensor& weights)
{

View File

@ -94,9 +94,7 @@ namespace bayesnet {
}
hyperparameters.erase("select_features");
}
if (!hyperparameters.empty()) {
throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump());
}
Classifier::setHyperparameters(hyperparameters);
}
std::tuple<torch::Tensor&, double, bool> update_weights(torch::Tensor& ytrain, torch::Tensor& ypred, torch::Tensor& weights)
{

View File

@ -20,7 +20,7 @@ namespace bayesnet {
BoostAODE(bool predict_voting = false);
virtual ~BoostAODE() = default;
std::vector<std::string> graph(const std::string& title = "BoostAODE") const override;
void setHyperparameters(const nlohmann::json& hyperparameters) override;
void setHyperparameters(const nlohmann::json& hyperparameters_) override;
protected:
void buildModel(const torch::Tensor& weights) override;
void trainModel(const torch::Tensor& weights) override;

View File

@ -8,12 +8,13 @@ if(ENABLE_TESTING)
${CMAKE_BINARY_DIR}/configured_files/include
)
file(GLOB_RECURSE BayesNet_SOURCES "${BayesNet_SOURCE_DIR}/bayesnet/*.cc")
add_executable(TestBayesNet TestBayesNetwork.cc TestBayesNode.cc TestBayesModels.cc TestBayesMetrics.cc TestFeatureSelection.cc TestUtils.cc ${BayesNet_SOURCES})
add_executable(TestBayesNet TestBayesNetwork.cc TestBayesNode.cc TestBayesClassifier.cc TestBayesModels.cc TestBayesMetrics.cc TestFeatureSelection.cc TestUtils.cc ${BayesNet_SOURCES})
target_link_libraries(TestBayesNet PUBLIC "${TORCH_LIBRARIES}" ArffFiles mdlp Catch2::Catch2WithMain )
add_test(NAME BayesNetworkTest COMMAND TestBayesNet)
add_test(NAME Network COMMAND TestBayesNet "[Network]")
add_test(NAME Node COMMAND TestBayesNet "[Node]")
add_test(NAME Metrics COMMAND TestBayesNet "[Metrics]")
add_test(NAME FeatureSelection COMMAND TestBayesNet "[FeatureSelection]")
add_test(NAME Classifier COMMAND TestBayesNet "[Classifier]")
add_test(NAME Models COMMAND TestBayesNet "[Models]")
endif(ENABLE_TESTING)

View File

@ -0,0 +1,23 @@
#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers.hpp>
#include <string>
#include "TestUtils.h"
#include "bayesnet/classifiers/TAN.h"
TEST_CASE("Test Cannot build dataset with wrong data vector", "[Classifier]")
{
auto model = bayesnet::TAN();
auto raw = RawDatasets("iris", true);
raw.yv.pop_back();
REQUIRE_THROWS_AS(model.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv), std::runtime_error);
REQUIRE_THROWS_WITH(model.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv), "* Error in X and y dimensions *\nX dimensions: [4, 150]\ny dimensions: [149]");
}
TEST_CASE("Test Cannot build dataset with wrong data tensor", "[Classifier]")
{
auto model = bayesnet::TAN();
auto raw = RawDatasets("iris", true);
auto yshort = torch::zeros({ 149 }, torch::kInt32);
REQUIRE_THROWS_AS(model.fit(raw.Xt, yshort, raw.featurest, raw.classNamet, raw.statest), std::runtime_error);
REQUIRE_THROWS_WITH(model.fit(raw.Xt, yshort, raw.featurest, raw.classNamet, raw.statest), "* Error in X and y dimensions *\nX dimensions: [4, 150]\ny dimensions: [149]");
}