From c713c0b1df196483a849097ff25818090f429ab9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Sun, 26 Nov 2023 10:36:09 +0100 Subject: [PATCH] Add continue from parameter to gridsearch --- src/Platform/GridSearch.cc | 60 ++++++++++++++++++++++++++++++++++---- src/Platform/GridSearch.h | 4 +-- src/Platform/Result.h | 1 - src/Platform/Results.h | 2 -- src/Platform/b_grid.cc | 3 +- 5 files changed, 58 insertions(+), 12 deletions(-) diff --git a/src/Platform/GridSearch.cc b/src/Platform/GridSearch.cc index fd2cbbe..85dcd26 100644 --- a/src/Platform/GridSearch.cc +++ b/src/Platform/GridSearch.cc @@ -7,6 +7,26 @@ #include "Colors.h" namespace platform { + std::string get_date() + { + time_t rawtime; + tm* timeinfo; + time(&rawtime); + timeinfo = std::localtime(&rawtime); + std::ostringstream oss; + oss << std::put_time(timeinfo, "%Y-%m-%d"); + return oss.str(); + } + std::string get_time() + { + time_t rawtime; + tm* timeinfo; + time(&rawtime); + timeinfo = std::localtime(&rawtime); + std::ostringstream oss; + oss << std::put_time(timeinfo, "%H:%M:%S"); + return oss.str(); + } GridSearch::GridSearch(struct ConfigGrid& config) : config(config) { this->config.output_file = config.path + "grid_" + config.model + "_output.json"; @@ -43,7 +63,6 @@ namespace platform { auto [X, y] = datasets.getTensors(fileName); auto states = datasets.getStates(fileName); auto features = datasets.getFeatures(fileName); - auto samples = datasets.getNSamples(fileName); auto className = datasets.getClassName(fileName); double totalScore = 0.0; int numItems = 0; @@ -86,6 +105,33 @@ namespace platform { { // Load datasets auto datasets = Datasets(config.discretize, Paths::datasets()); + // Load previous results + json results; + auto datasets_names = datasets.getNames(); + if (config.continue_from != "no") { + if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) { + throw std::invalid_argument("Dataset " + config.continue_from + " not found"); + } + if (!config.quiet) + std::cout << "* Loading previous results" << std::endl; + try { + std::ifstream file(config.output_file); + if (file.is_open()) { + results = json::parse(file); + } + } + catch (const std::exception& e) { + std::cerr << "Error loading previous results: " << e.what() << std::endl; + } + // Remove datasets already processed + vector< string >::iterator it = datasets_names.begin(); + while (it != datasets_names.end()) { + if (*it != config.continue_from) { + it = datasets_names.erase(it); + } else + break; + } + } // Create model std::cout << "***************** Starting Gridsearch *****************" << std::endl; std::cout << "input file=" << config.input_file << std::endl; @@ -94,7 +140,7 @@ namespace platform { std::cout << "* Doing " << totalComb << " combinations for each dataset/seed/fold" << std::endl; // Generate hyperparameters grid & run gridsearch // Check each combination of hyperparameters for each dataset and each seed - for (const auto& dataset : datasets.getNames()) { + for (const auto& dataset : datasets_names) { if (!config.quiet) std::cout << "- " << setw(20) << left << dataset << " " << right << flush; int num = 0; @@ -116,15 +162,17 @@ namespace platform { } results[dataset]["score"] = bestScore; results[dataset]["hyperparameters"] = bestHyperparameters; + results[dataset]["date"] = get_date() + " " + get_time(); + // Save partial results + save(results); } - // Save results - save(); + // Save final results + save(results); std::cout << "***************** Ending Gridsearch *******************" << std::endl; } - void GridSearch::save() const + void GridSearch::save(json& results) const { std::ofstream file(config.output_file); file << results.dump(4); - file.close(); } } /* namespace platform */ \ No newline at end of file diff --git a/src/Platform/GridSearch.h b/src/Platform/GridSearch.h index c5528d0..226fd27 100644 --- a/src/Platform/GridSearch.h +++ b/src/Platform/GridSearch.h @@ -15,6 +15,7 @@ namespace platform { std::string path; std::string input_file; std::string output_file; + std::string continue_from; bool quiet; bool discretize; bool stratified; @@ -25,11 +26,10 @@ namespace platform { public: explicit GridSearch(struct ConfigGrid& config); void go(); - void save() const; ~GridSearch() = default; private: + void save(json& results) const; double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters); - json results; struct ConfigGrid config; }; } /* namespace platform */ diff --git a/src/Platform/Result.h b/src/Platform/Result.h index 85ec832..10459b7 100644 --- a/src/Platform/Result.h +++ b/src/Platform/Result.h @@ -32,5 +32,4 @@ namespace platform { bool complete; }; }; - #endif \ No newline at end of file diff --git a/src/Platform/Results.h b/src/Platform/Results.h index aa293d8..9f9023f 100644 --- a/src/Platform/Results.h +++ b/src/Platform/Results.h @@ -7,7 +7,6 @@ #include "Result.h" namespace platform { using json = nlohmann::json; - class Results { public: Results(const std::string& path, const std::string& model, const std::string& score, bool complete, bool partial); @@ -34,5 +33,4 @@ namespace platform { void load(); // Loads the list of results }; }; - #endif \ No newline at end of file diff --git a/src/Platform/b_grid.cc b/src/Platform/b_grid.cc index 8887892..8dcb12f 100644 --- a/src/Platform/b_grid.cc +++ b/src/Platform/b_grid.cc @@ -24,6 +24,7 @@ argparse::ArgumentParser manageArguments(std::string program_name) ); program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true); program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true); + program.add_argument("--continue").help("Continue computing from that dataset").default_value("No"); program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true); program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy"); program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) { @@ -58,6 +59,7 @@ int main(int argc, char** argv) config.n_folds = program.get("folds"); config.quiet = program.get("quiet"); config.seeds = program.get>("seeds"); + config.continue_from = program.get("continue"); } catch (const exception& err) { cerr << err.what() << std::endl; @@ -75,7 +77,6 @@ int main(int argc, char** argv) timer.start(); grid_search.go(); std::cout << "Process took " << timer.getDurationString() << std::endl; - grid_search.save(); std::cout << "Done!" << std::endl; return 0; }