Support b_main with best hyperparameters

This commit is contained in:
2024-08-02 19:10:25 +02:00
parent 97abec8b69
commit 0ea967dd9d
6 changed files with 36 additions and 17 deletions

View File

@@ -6,6 +6,7 @@
#include <algorithm>
#include "common/Colors.h"
#include "common/CLocale.h"
#include "common/Paths.h"
#include "results/Result.h"
#include "BestResultsExcel.h"
#include "best/Statistics.h"
@@ -59,16 +60,12 @@ namespace platform {
std::cerr << Colors::MAGENTA() << "No results found for model " << model << " and score " << score << Colors::RESET() << std::endl;
exit(1);
}
std::string bestFileName = path + bestResultFile();
std::string bestFileName = path + Paths::bestResultsFile(score, model);
std::ofstream file(bestFileName);
file << bests;
file.close();
return bestFileName;
}
std::string BestResults::bestResultFile()
{
    // File name only (no directory): best_results_<score>_<model>.json,
    // built from the score and model currently held by this object.
    const std::string prefix = "best_results_";
    const std::string extension = ".json";
    return prefix + score + "_" + model + extension;
}
std::pair<std::string, std::string> getModelScore(std::string name)
{
// results_accuracy_BoostAODE_MacBookpro16_2023-09-06_12:27:00_1.json
@@ -150,7 +147,7 @@ namespace platform {
}
void BestResults::listFile()
{
std::string bestFileName = path + bestResultFile();
std::string bestFileName = path + Paths::bestResultsFile(score, model);
if (FILE* fileTest = fopen(bestFileName.c_str(), "r")) {
fclose(fileTest);
} else {
@@ -196,7 +193,7 @@ namespace platform {
auto maxDate = std::filesystem::file_time_type::max();
for (const auto& model : models) {
this->model = model;
std::string bestFileName = path + bestResultFile();
std::string bestFileName = path + Paths::bestResultsFile(score, model);
if (FILE* fileTest = fopen(bestFileName.c_str(), "r")) {
fclose(fileTest);
} else {
@@ -306,7 +303,7 @@ namespace platform {
json table = buildTableResults(models);
std::vector<std::string> datasets = getDatasets(table.begin().value());
BestResultsExcel excel_report(score, datasets);
excel_report.reportSingle(model, path + bestResultFile());
excel_report.reportSingle(model, path + Paths::bestResultsFile(score, model));
messageExcelFile(excel_report.getFileName());
}
}
@@ -346,7 +343,7 @@ namespace platform {
}
}
model = models.at(idx);
excel.reportSingle(model, path + bestResultFile());
excel.reportSingle(model, path + Paths::bestResultsFile(score, model));
}
messageExcelFile(excel.getFileName());
}

View File

@@ -22,7 +22,6 @@ namespace platform {
void messageExcelFile(const std::string& fileName);
json buildTableResults(std::vector<std::string> models);
void printTableResults(std::vector<std::string> models, json table);
std::string bestResultFile();
json loadFile(const std::string& fileName);
void listFile();
std::string path;

View File

@@ -36,6 +36,7 @@ void manageArguments(argparse::ArgumentParser& program)
program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment");
program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \
"Mutually exclusive with hyperparameters. This file should contain hyperparameters for each dataset in json format.");
program.add_argument("--hyper-best").default_value(false).help("Use best results of the model as source of hyperparameters").implicit_value(true);
program.add_argument("-m", "--model")
.help("Model to use: " + platform::Models::instance()->toString())
.action([](const std::string& value) {
@@ -93,7 +94,7 @@ int main(int argc, char** argv)
manageArguments(program);
std::string file_name, model_name, title, hyperparameters_file, datasets_file, discretize_algo, smooth_strat, score;
json hyperparameters_json;
bool discretize_dataset, stratified, saveResults, quiet, no_train_score, generate_fold_files, graph;
bool discretize_dataset, stratified, saveResults, quiet, no_train_score, generate_fold_files, graph, hyper_best;
std::vector<int> seeds;
std::vector<std::string> file_names;
std::vector<std::string> filesToTest;
@@ -117,9 +118,17 @@ int main(int argc, char** argv)
hyperparameters_json = json::parse(hyperparameters);
hyperparameters_file = program.get<std::string>("hyper-file");
no_train_score = program.get<bool>("no-train-score");
hyper_best = program.get<bool>("hyper-best");
generate_fold_files = program.get<bool>("generate-fold-files");
if (hyperparameters_file != "" && hyperparameters != "{}") {
throw runtime_error("hyperparameters and hyper_file are mutually exclusive");
if (hyper_best) {
// Build the best results file_name
hyperparameters_file = platform::Paths::results() + platform::Paths::bestResultsFile(score, model_name);
// ignore this parameter
hyperparameters = "{}";
} else {
if (hyperparameters_file != "" && hyperparameters != "{}") {
throw runtime_error("hyperparameters and hyper_file are mutually exclusive");
}
}
title = program.get<std::string>("title");
if (title == "" && file_name == "all") {
@@ -188,7 +197,7 @@ int main(int argc, char** argv)
platform::HyperParameters test_hyperparams;
if (hyperparameters_file != "") {
test_hyperparams = platform::HyperParameters(datasets.getNames(), hyperparameters_file);
test_hyperparams = platform::HyperParameters(datasets.getNames(), hyperparameters_file, hyper_best);
} else {
test_hyperparams = platform::HyperParameters(datasets.getNames(), hyperparameters_json);
}

View File

@@ -32,6 +32,10 @@ namespace platform {
throw std::runtime_error("Could not create directory " + path);
}
}
// Builds the file name (without any directory component) of the best-results
// JSON file for a given score/model pair.
//   score: score name embedded in the file name
//   model: model name embedded in the file name
// Returns "best_results_<score>_<model>.json".
static std::string bestResultsFile(const std::string& score, const std::string& model)
{
    std::string fileName("best_results_");
    fileName.append(score);
    fileName.append("_");
    fileName.append(model);
    fileName.append(".json");
    return fileName;
}
// Returns the constant file name "some_results.xlsx".
static std::string excelResults()
{
    return std::string("some_results.xlsx");
}
static std::string grid_input(const std::string& model)
{

View File

@@ -19,7 +19,7 @@ namespace platform {
std::ostream_iterator<std::string>(ss, delim.c_str()));
return ss.str();
}
HyperParameters::HyperParameters(const std::vector<std::string>& datasets, const std::string& hyperparameters_file)
HyperParameters::HyperParameters(const std::vector<std::string>& datasets, const std::string& hyperparameters_file, bool best)
{
// Check if file exists
std::ifstream file(hyperparameters_file);
@@ -28,7 +28,14 @@ namespace platform {
}
// Check if file is a json
json file_hyperparameters = json::parse(file);
auto input_hyperparameters = file_hyperparameters["results"];
json input_hyperparameters;
if (best) {
for (const auto& [key, value] : file_hyperparameters.items()) {
input_hyperparameters[key] = value[1];
}
} else {
input_hyperparameters = file_hyperparameters["results"];
}
// Check if hyperparameters are valid
for (const auto& dataset : datasets) {
if (!input_hyperparameters.contains(dataset)) {

View File

@@ -10,14 +10,17 @@ namespace platform {
// Container for model hyperparameters stored as JSON values in a string-keyed map.
// NOTE(review): the keys appear to be dataset names (the file-based constructor
// validates its input against the `datasets` list) — confirm against the implementation.
class HyperParameters {
public:
HyperParameters() = default;
// Constructor to use command line hyperparameters
explicit HyperParameters(const std::vector<std::string>& datasets, const json& hyperparameters_);
// NOTE(review): this two-argument file overload overlaps with the three-argument
// overload below when `best` takes its default value — confirm only one is meant to remain.
explicit HyperParameters(const std::vector<std::string>& datasets, const std::string& hyperparameters_file);
// Constructor to use hyperparameters file generated by grid or by best results.
// When `best` is true the file is read in the best-results layout (element [1] of
// each top-level entry is taken); otherwise the "results" object of the file is used.
explicit HyperParameters(const std::vector<std::string>& datasets, const std::string& hyperparameters_file, bool best = false);
~HyperParameters() = default;
// True when the entry stored under `key` is not empty.
// Uses std::map::at, so a missing key throws std::out_of_range.
bool notEmpty(const std::string& key) const { return !hyperparameters.at(key).empty(); }
// NOTE(review): implementation not visible here; presumably validates the stored
// hyperparameters for `fileName` against the `valid` names — confirm.
void check(const std::vector<std::string>& valid, const std::string& fileName);
// NOTE(review): implementation not visible here; presumably returns the
// hyperparameters associated with `fileName` — confirm.
json get(const std::string& fileName);
private:
// Hyperparameters, one JSON value per key.
std::map<std::string, json> hyperparameters;
bool best = false; // Distinguishes grid files from best-results files, because the two file formats differ
};
} /* namespace platform */
#endif