Add excel report to manage results #6

Merged
rmontanana merged 8 commits from xlsx into main 2023-08-22 21:40:12 +00:00
3 changed files with 11 additions and 6 deletions
Showing only changes of commit 97ca8ac084 - Show all commits

View File

@ -11,12 +11,8 @@ namespace bayesnet {
void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters)
{
// Check if hyperparameters are valid
auto validKeys = { "repeatSparent", "maxModels", "ascending" };
for (const auto& item : hyperparameters.items()) {
if (find(validKeys.begin(), validKeys.end(), item.key()) == validKeys.end()) {
throw invalid_argument("Hyperparameter " + item.key() + " is not valid");
}
}
const vector<string> validKeys = { "repeatSparent", "maxModels", "ascending" };
checkHyperparameters(validKeys, hyperparameters);
if (hyperparameters.contains("repeatSparent")) {
repeatSparent = hyperparameters["repeatSparent"];
}

View File

@ -152,4 +152,12 @@ namespace bayesnet {
{
model.dump_cpt();
}
void Classifier::checkHyperparameters(const vector<string>& validKeys, nlohmann::json& hyperparameters)
{
for (const auto& item : hyperparameters.items()) {
if (find(validKeys.begin(), validKeys.end(), item.key()) == validKeys.end()) {
throw invalid_argument("Hyperparameter " + item.key() + " is not valid");
}
}
}
}

View File

@ -24,6 +24,7 @@ namespace bayesnet {
void checkFitParameters();
virtual void buildModel(const torch::Tensor& weights) = 0;
void trainModel(const torch::Tensor& weights) override;
void checkHyperparameters(const vector<string>& validKeys, nlohmann::json& hyperparameters);
public:
Classifier(Network model);
virtual ~Classifier() = default;