From 8dbbb65a2fcd837bdbaea678f7cab38ac73ecf1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Tue, 28 Nov 2023 10:08:40 +0100 Subject: [PATCH] Add only parameter to gridsearch --- .vscode/launch.json | 25 ++++++++++++++++++++----- src/Platform/GridSearch.cc | 12 +++++++++--- src/Platform/GridSearch.h | 1 + src/Platform/b_grid.cc | 7 ++++++- 4 files changed, 36 insertions(+), 9 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 89beff3..0ea8fbd 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -14,7 +14,7 @@ "-s", "271", "-p", - "/home/rmontanana/Code/discretizbench/datasets/", + "/Users/rmontanana/Code/discretizbench/datasets/", ], //"cwd": "${workspaceFolder}/build/sample/", }, @@ -33,7 +33,22 @@ // "--hyperparameters", // "{\"repeatSparent\": true, \"maxModels\": 12}" ], - "cwd": "/home/rmontanana/Code/discretizbench", + "cwd": "${workspaceFolder}/../discretizbench", + }, + { + "type": "lldb", + "request": "launch", + "name": "gridsearch", + "program": "${workspaceFolder}/build_debug/src/Platform/b_grid", + "args": [ + "-m", + "KDB", + "--discretize", + "--continue", + "glass", + "--only" + ], + "cwd": "${workspaceFolder}/../discretizbench", }, { "type": "lldb", @@ -64,7 +79,7 @@ "accuracy", "--build", ], - "cwd": "/home/rmontanana/Code/discretizbench", + "cwd": "${workspaceFolder}/../discretizbench", }, { "type": "lldb", @@ -75,7 +90,7 @@ "-n", "20" ], - "cwd": "/home/rmontanana/Code/discretizbench", + "cwd": "${workspaceFolder}/../discretizbench", }, { "type": "lldb", @@ -84,7 +99,7 @@ "program": "${workspaceFolder}/build_debug/src/Platform/b_list", "args": [], //"cwd": "/Users/rmontanana/Code/discretizbench", - "cwd": "/home/rmontanana/Code/covbench", + "cwd": "${workspaceFolder}/../discretizbench", }, { "type": "lldb", diff --git a/src/Platform/GridSearch.cc b/src/Platform/GridSearch.cc index 2546131..2074514 100644 --- a/src/Platform/GridSearch.cc +++ b/src/Platform/GridSearch.cc @@ -123,15 +123,21 @@ namespace platform { } } catch (const std::exception& e) { - std::cerr << "Error loading previous results: " << e.what() << std::endl; + std::cerr << "* There were no previous results" << std::endl; + std::cerr << "* Initizalizing new results" << std::endl; + results = json(); } // Remove datasets already processed vector< string >::iterator it = datasets_names.begin(); while (it != datasets_names.end()) { if (*it != config.continue_from) { it = datasets_names.erase(it); - } else - break; + } else { + if (config.only) + ++it; + else + break; + } } } // Create model diff --git a/src/Platform/GridSearch.h b/src/Platform/GridSearch.h index 226fd27..b52c1e9 100644 --- a/src/Platform/GridSearch.h +++ b/src/Platform/GridSearch.h @@ -17,6 +17,7 @@ namespace platform { std::string output_file; std::string continue_from; bool quiet; + bool only; // used with continue_from to only compute that dataset bool discretize; bool stratified; int n_folds; diff --git a/src/Platform/b_grid.cc b/src/Platform/b_grid.cc index 8dcb12f..6a63332 100644 --- a/src/Platform/b_grid.cc +++ b/src/Platform/b_grid.cc @@ -23,9 +23,10 @@ argparse::ArgumentParser manageArguments(std::string program_name) } ); program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true); + program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true); program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true); program.add_argument("--continue").help("Continue computing from that dataset").default_value("No"); - program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true); + program.add_argument("--only").help("Used with continue to compute that dataset only").default_value(false).implicit_value(true); program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy"); program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) { try { @@ -58,8 +59,12 @@ int main(int argc, char** argv) config.stratified = program.get("stratified"); config.n_folds = program.get("folds"); config.quiet = program.get("quiet"); + config.only = program.get("only"); config.seeds = program.get>("seeds"); config.continue_from = program.get("continue"); + if (config.continue_from == "No" && config.only) { + throw std::runtime_error("Cannot use --only without --continue"); + } } catch (const exception& err) { cerr << err.what() << std::endl;