Add header to grid output and report

This commit is contained in:
Ricardo Montañana Gómez 2023-12-01 10:30:53 +01:00
parent c460ef46ed
commit 33cd32c639
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 117 additions and 23 deletions

View File

@ -63,7 +63,7 @@ namespace platform {
return Colors::RESET(); return Colors::RESET();
} }
} }
double GridSearch::processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters) double GridSearch::processFileSingle(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters)
{ {
// Get dataset // Get dataset
auto [X, y] = datasets.getTensors(fileName); auto [X, y] = datasets.getTensors(fileName);
@ -135,11 +135,8 @@ namespace platform {
} }
return datasets_names; return datasets_names;
} }
json GridSearch::initializeResults()
void GridSearch::go()
{ {
auto datasets = Datasets(config.discretize, Paths::datasets());
auto datasets_names = processDatasets(datasets);
// Load previous results // Load previous results
json results; json results;
if (config.continue_from != "No") { if (config.continue_from != "No") {
@ -149,6 +146,7 @@ namespace platform {
std::ifstream file(Paths::grid_output(config.model)); std::ifstream file(Paths::grid_output(config.model));
if (file.is_open()) { if (file.is_open()) {
results = json::parse(file); results = json::parse(file);
results = results["results"];
} }
} }
catch (const std::exception& e) { catch (const std::exception& e) {
@ -157,7 +155,15 @@ namespace platform {
results = json(); results = json();
} }
} }
std::cout << "***************** Starting Gridsearch *****************" << std::endl; return results;
}
void GridSearch::goSingle()
{
auto datasets = Datasets(config.discretize, Paths::datasets());
auto datasets_names = processDatasets(datasets);
json results = initializeResults();
std::cout << "***************** Starting Single Gridsearch *****************" << std::endl;
std::cout << "input file=" << Paths::grid_input(config.model) << std::endl; std::cout << "input file=" << Paths::grid_input(config.model) << std::endl;
auto grid = GridData(Paths::grid_input(config.model)); auto grid = GridData(Paths::grid_input(config.model));
// Generate hyperparameters grid & run gridsearch // Generate hyperparameters grid & run gridsearch
@ -174,7 +180,7 @@ namespace platform {
if (!config.quiet) if (!config.quiet)
showProgressComb(++num, totalComb, Colors::CYAN()); showProgressComb(++num, totalComb, Colors::CYAN());
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line); auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
double score = processFile(dataset, datasets, hyperparameters); double score = processFileSingle(dataset, datasets, hyperparameters);
if (score > bestScore) { if (score > bestScore) {
bestScore = score; bestScore = score;
bestHyperparameters = hyperparam_line; bestHyperparameters = hyperparam_line;
@ -184,20 +190,80 @@ namespace platform {
std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed
<< bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl; << bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl;
} }
results[dataset]["score"] = bestScore; json result = {
results[dataset]["hyperparameters"] = bestHyperparameters; { "score", bestScore },
results[dataset]["date"] = get_date() + " " + get_time(); { "hyperparameters", bestHyperparameters },
results[dataset]["grid"] = grid.getInputGrid(dataset); { "date", get_date() + " " + get_time() },
{ "grid", grid.getInputGrid(dataset) }
};
results[dataset] = result;
// Save partial results // Save partial results
save(results); save(results);
} }
// Save final results // Save final results
save(results); save(results);
std::cout << "***************** Ending Gridsearch *******************" << std::endl; std::cout << "***************** Ending Single Gridsearch *******************" << std::endl;
}
void GridSearch::goNested()
{
auto datasets = Datasets(config.discretize, Paths::datasets());
auto datasets_names = processDatasets(datasets);
json results = initializeResults();
std::cout << "***************** Starting Nested Gridsearch *****************" << std::endl;
std::cout << "input file=" << Paths::grid_input(config.model) << std::endl;
auto grid = GridData(Paths::grid_input(config.model));
// Generate hyperparameters grid & run gridsearch
// Check each combination of hyperparameters for each dataset and each seed
for (const auto& dataset : datasets_names) {
auto totalComb = grid.getNumCombinations(dataset);
if (!config.quiet)
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
int num = 0;
double bestScore = 0.0;
json bestHyperparameters;
auto combinations = grid.getGrid(dataset);
for (const auto& hyperparam_line : combinations) {
if (!config.quiet)
showProgressComb(++num, totalComb, Colors::CYAN());
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
double score = processFileSingle(dataset, datasets, hyperparameters);
if (score > bestScore) {
bestScore = score;
bestHyperparameters = hyperparam_line;
}
}
if (!config.quiet) {
std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed
<< bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl;
}
json result = {
{ "score", bestScore },
{ "hyperparameters", bestHyperparameters },
{ "date", get_date() + " " + get_time() },
{ "grid", grid.getInputGrid(dataset) }
};
results[dataset] = result;
// Save partial results
save(results);
}
// Save final results
save(results);
std::cout << "***************** Ending Nested Gridsearch *******************" << std::endl;
} }
/**
 * Write the grid search output file for the configured model.
 *
 * The file contains a header with the run configuration (model, score,
 * discretize, stratified, n_folds, seeds, nested) and the current date and
 * time, followed by the per-dataset results under the "results" key.
 *
 * @param results per-dataset results to store under "results"
 */
void GridSearch::save(json& results) const
{
    std::ofstream file(Paths::grid_output(config.model));
    // Header describing how the results were obtained
    json output;
    output["model"] = config.model;
    output["score"] = config.score;
    output["discretize"] = config.discretize;
    output["stratified"] = config.stratified;
    output["n_folds"] = config.n_folds;
    output["seeds"] = config.seeds;
    output["date"] = get_date() + " " + get_time();
    output["nested"] = config.nested;
    // Payload
    output["results"] = results;
    file << output.dump(4);
}
} /* namespace platform */ } /* namespace platform */

View File

@ -17,19 +17,22 @@ namespace platform {
bool only; // used with continue_from to only compute that dataset bool only; // used with continue_from to only compute that dataset
bool discretize; bool discretize;
bool stratified; bool stratified;
int nested;
int n_folds; int n_folds;
std::vector<int> seeds; std::vector<int> seeds;
}; };
class GridSearch { class GridSearch {
public: public:
explicit GridSearch(struct ConfigGrid& config); explicit GridSearch(struct ConfigGrid& config);
void go(); void goSingle();
void goNested();
~GridSearch() = default; ~GridSearch() = default;
json getResults(); json getResults();
private: private:
void save(json& results) const; void save(json& results) const;
json initializeResults();
vector<std::string> processDatasets(Datasets& datasets); vector<std::string> processDatasets(Datasets& datasets);
double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters); double processFileSingle(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters);
struct ConfigGrid config; struct ConfigGrid config;
}; };
} /* namespace platform */ } /* namespace platform */

View File

@ -11,6 +11,7 @@
#include "Colors.h" #include "Colors.h"
using json = nlohmann::json; using json = nlohmann::json;
const int MAXL = 133;
void manageArguments(argparse::ArgumentParser& program) void manageArguments(argparse::ArgumentParser& program)
{ {
@ -27,13 +28,14 @@ void manageArguments(argparse::ArgumentParser& program)
} }
); );
group.add_argument("--dump").help("Show the grid combinations").default_value(false).implicit_value(true); group.add_argument("--dump").help("Show the grid combinations").default_value(false).implicit_value(true);
group.add_argument("--list").help("List the computed hyperparameters").default_value(false).implicit_value(true); group.add_argument("--report").help("Report the computed hyperparameters").default_value(false).implicit_value(true);
group.add_argument("--compute").help("Perform computation of the grid output hyperparameters").default_value(false).implicit_value(true); group.add_argument("--compute").help("Perform computation of the grid output hyperparameters").default_value(false).implicit_value(true);
program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true); program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true); program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true); program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
program.add_argument("--continue").help("Continue computing from that dataset").default_value("No"); program.add_argument("--continue").help("Continue computing from that dataset").default_value("No");
program.add_argument("--only").help("Used with continue to compute that dataset only").default_value(false).implicit_value(true); program.add_argument("--only").help("Used with continue to compute that dataset only").default_value(false).implicit_value(true);
program.add_argument("--nested").help("Do a double/nested cross validation with n folds").default_value(0).scan<'i', int>();
program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy"); program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy");
program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) { program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) {
try { try {
@ -83,13 +85,29 @@ void list_dump(std::string& model)
} }
std::cout << Colors::RESET() << std::endl; std::cout << Colors::RESET() << std::endl;
} }
/**
 * Build one framed line of the report header.
 *
 * The text is prefixed with "* ", right-padded with spaces and closed with
 * "*\n", so that a text shorter than MAXL - 3 characters yields a line of
 * exactly MAXL characters (plus utf) before the newline. Longer texts are
 * emitted without padding.
 *
 * @param text content shown inside the frame
 * @param utf  extra padding added on top of the computed padding (presumably
 *             compensates for multi-byte UTF-8 characters that length() counts
 *             more than once — confirm with callers)
 * @return the framed line, terminated by a newline
 */
std::string headerLine(const std::string& text, int utf = 0)
{
    // Stay in signed arithmetic: text.length() is unsigned, and subtracting it
    // from MAXL directly would wrap around instead of going negative.
    int n = MAXL - static_cast<int>(text.length()) - 3;
    n = n < 0 ? 0 : n;
    // A negative total must not reach std::string(count, ch): it would be
    // converted to a huge unsigned count and throw std::length_error.
    int padding = n + utf;
    padding = padding < 0 ? 0 : padding;
    return "* " + text + std::string(padding, ' ') + "*\n";
}
void list_results(json& results, std::string& model) void list_results(json& results, std::string& model)
{ {
std::cout << Colors::MAGENTA() << "Listing computed hyperparameters for model " std::cout << Colors::MAGENTA() << std::string(MAXL, '*') << std::endl;
<< model << std::endl << std::endl; std::cout << headerLine("Listing computed hyperparameters for model " + model);
std::cout << headerLine("Date & time: " + results["date"].get<std::string>());
std::cout << headerLine("Score: " + results["score"].get<std::string>());
std::cout << headerLine(
"Random seeds: " + results["seeds"].dump()
+ " Discretized: " + (results["discretize"].get<bool>() ? "True" : "False")
+ " Stratified: " + (results["stratified"].get<bool>() ? "True" : "False")
+ " #Folds: " + std::to_string(results["n_folds"].get<int>())
+ " Nested: " + (results["nested"].get<int>() == 0 ? "False" : to_string(results["nested"].get<int>()))
);
std::cout << std::string(MAXL, '*') << std::endl;
int spaces = 0; int spaces = 0;
int hyperparameters_spaces = 0; int hyperparameters_spaces = 0;
for (const auto& item : results.items()) { for (const auto& item : results["results"].items()) {
auto key = item.key(); auto key = item.key();
auto value = item.value(); auto value = item.value();
if (key.size() > spaces) { if (key.size() > spaces) {
@ -105,7 +123,7 @@ void list_results(json& results, std::string& model)
<< string(hyperparameters_spaces, '=') << std::endl; << string(hyperparameters_spaces, '=') << std::endl;
bool odd = true; bool odd = true;
int index = 0; int index = 0;
for (const auto& item : results.items()) { for (const auto& item : results["results"].items()) {
auto color = odd ? Colors::CYAN() : Colors::BLUE(); auto color = odd ? Colors::CYAN() : Colors::BLUE();
auto key = item.key(); auto key = item.key();
auto value = item.value(); auto value = item.value();
@ -119,12 +137,16 @@ void list_results(json& results, std::string& model)
std::cout << Colors::RESET() << std::endl; std::cout << Colors::RESET() << std::endl;
} }
/*
* Main
*/
int main(int argc, char** argv) int main(int argc, char** argv)
{ {
argparse::ArgumentParser program("b_grid"); argparse::ArgumentParser program("b_grid");
manageArguments(program); manageArguments(program);
struct platform::ConfigGrid config; struct platform::ConfigGrid config;
bool dump, compute, list; bool dump, compute;
try { try {
program.parse_args(argc, argv); program.parse_args(argc, argv);
config.model = program.get<std::string>("model"); config.model = program.get<std::string>("model");
@ -135,13 +157,13 @@ int main(int argc, char** argv)
config.quiet = program.get<bool>("quiet"); config.quiet = program.get<bool>("quiet");
config.only = program.get<bool>("only"); config.only = program.get<bool>("only");
config.seeds = program.get<std::vector<int>>("seeds"); config.seeds = program.get<std::vector<int>>("seeds");
config.nested = program.get<int>("nested");
config.continue_from = program.get<std::string>("continue"); config.continue_from = program.get<std::string>("continue");
if (config.continue_from == "No" && config.only) { if (config.continue_from == "No" && config.only) {
throw std::runtime_error("Cannot use --only without --continue"); throw std::runtime_error("Cannot use --only without --continue");
} }
dump = program.get<bool>("dump"); dump = program.get<bool>("dump");
compute = program.get<bool>("compute"); compute = program.get<bool>("compute");
list = program.get<bool>("list");
if (dump && (config.continue_from != "No" || config.only)) { if (dump && (config.continue_from != "No" || config.only)) {
throw std::runtime_error("Cannot use --dump with --continue or --only"); throw std::runtime_error("Cannot use --dump with --continue or --only");
} }
@ -163,7 +185,10 @@ int main(int argc, char** argv)
list_dump(config.model); list_dump(config.model);
} else { } else {
if (compute) { if (compute) {
grid_search.go(); if (config.nested == 0)
grid_search.goSingle();
else
grid_search.goNested();
std::cout << "Process took " << timer.getDurationString() << std::endl; std::cout << "Process took " << timer.getDurationString() << std::endl;
} else { } else {
// List results // List results