Refactor hyperparameters classifier management

This commit is contained in:
2024-04-08 00:55:30 +02:00
parent 0d6a081d01
commit 9014649a0d
8 changed files with 43 additions and 15 deletions

View File

@@ -1,3 +1,4 @@
#include <sstream>
#include "bayesnet/utils/bayesnetUtils.h"
#include "Classifier.h"
@@ -27,10 +28,11 @@ namespace bayesnet {
dataset = torch::cat({ dataset, yresized }, 0);
}
catch (const std::exception& e) {
std::cerr << e.what() << '\n';
std::cout << "X dimensions: " << dataset.sizes() << "\n";
std::cout << "y dimensions: " << ytmp.sizes() << "\n";
exit(1);
std::stringstream oss;
oss << "* Error in X and y dimensions *\n";
oss << "X dimensions: " << dataset.sizes() << "\n";
oss << "y dimensions: " << ytmp.sizes();
throw std::runtime_error(oss.str());
}
}
void Classifier::trainModel(const torch::Tensor& weights)
@@ -179,6 +181,8 @@ namespace bayesnet {
}
void Classifier::setHyperparameters(const nlohmann::json& hyperparameters)
{
//For classifiers that don't have hyperparameters
if (!hyperparameters.empty()) {
throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump());
}
}
}