Add excel report to manage results #6

Merged
rmontanana merged 8 commits from xlsx into main 2023-08-22 21:40:12 +00:00
3 changed files with 11 additions and 6 deletions
Showing only changes of commit 97ca8ac084 - Show all commits

View File

@ -11,12 +11,8 @@ namespace bayesnet {
void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters) void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters)
{ {
// Check if hyperparameters are valid // Check if hyperparameters are valid
auto validKeys = { "repeatSparent", "maxModels", "ascending" }; const vector<string> validKeys = { "repeatSparent", "maxModels", "ascending" };
for (const auto& item : hyperparameters.items()) { checkHyperparameters(validKeys, hyperparameters);
if (find(validKeys.begin(), validKeys.end(), item.key()) == validKeys.end()) {
throw invalid_argument("Hyperparameter " + item.key() + " is not valid");
}
}
if (hyperparameters.contains("repeatSparent")) { if (hyperparameters.contains("repeatSparent")) {
repeatSparent = hyperparameters["repeatSparent"]; repeatSparent = hyperparameters["repeatSparent"];
} }

View File

@ -152,4 +152,12 @@ namespace bayesnet {
{ {
model.dump_cpt(); model.dump_cpt();
} }
void Classifier::checkHyperparameters(const vector<string>& validKeys, nlohmann::json& hyperparameters)
{
for (const auto& item : hyperparameters.items()) {
if (find(validKeys.begin(), validKeys.end(), item.key()) == validKeys.end()) {
throw invalid_argument("Hyperparameter " + item.key() + " is not valid");
}
}
}
} }

View File

@ -24,6 +24,7 @@ namespace bayesnet {
void checkFitParameters(); void checkFitParameters();
virtual void buildModel(const torch::Tensor& weights) = 0; virtual void buildModel(const torch::Tensor& weights) = 0;
void trainModel(const torch::Tensor& weights) override; void trainModel(const torch::Tensor& weights) override;
void checkHyperparameters(const vector<string>& validKeys, nlohmann::json& hyperparameters);
public: public:
Classifier(Network model); Classifier(Network model);
virtual ~Classifier() = default; virtual ~Classifier() = default;