refactor gridsearch to have only one go method

This commit is contained in:
Ricardo Montañana Gómez 2023-12-02 10:59:05 +01:00
parent 33cd32c639
commit 03e4437fea
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
5 changed files with 176 additions and 137 deletions

3
.vscode/launch.json vendored
View File

@ -46,7 +46,8 @@
"--discretize", "--discretize",
"--continue", "--continue",
"glass", "glass",
"--only" "--only",
"--compute"
], ],
"cwd": "${workspaceFolder}/../discretizbench", "cwd": "${workspaceFolder}/../discretizbench",
}, },

View File

@ -38,10 +38,10 @@ namespace platform {
} }
return json(); return json();
} }
void showProgressComb(const int num, const int total, const std::string& color) void showProgressComb(const int num, const int n_folds, const int total, const std::string& color)
{ {
int spaces = int(log(total) / log(10)) + 1; int spaces = int(log(total) / log(10)) + 1;
int magic = 37 + 2 * spaces; int magic = n_folds * 3 + 22 + 2 * spaces;
std::string prefix = num == 1 ? "" : string(magic, '\b') + string(magic + 1, ' ') + string(magic + 1, '\b'); std::string prefix = num == 1 ? "" : string(magic, '\b') + string(magic + 1, ' ') + string(magic + 1, '\b');
std::cout << prefix << color << "(" << setw(spaces) << num << "/" << setw(spaces) << total << ") " << Colors::RESET() << flush; std::cout << prefix << color << "(" << setw(spaces) << num << "/" << setw(spaces) << total << ") " << Colors::RESET() << flush;
} }
@ -63,18 +63,120 @@ namespace platform {
return Colors::RESET(); return Colors::RESET();
} }
} }
double GridSearch::processFileSingle(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters) void GridSearch::go()
{
// Single entry point of the grid search. For every dataset name returned by
// processDatasets() it fetches that dataset's hyperparameter combinations,
// evaluates them in "Single" or "Nested" mode (chosen by config.nested) and
// stores score, hyperparameters, date, input grid and duration in
// results[dataset]. save(results) is called after every dataset and once more
// at the end.
// Start the member timer before any work is done
timer.start();
// config.nested == 0 selects the single grid search; any other value the nested one
auto grid_type = config.nested == 0 ? "Single" : "Nested";
auto datasets = Datasets(config.discretize, Paths::datasets());
auto datasets_names = processDatasets(datasets);
json results = initializeResults();
std::cout << "***************** Starting " << grid_type << " Gridsearch *****************" << std::endl;
std::cout << "input file=" << Paths::grid_input(config.model) << std::endl;
auto grid = GridData(Paths::grid_input(config.model));
// Separate timer, restarted for each dataset, used for the per-dataset "duration" field
Timer timer_dataset;
// Both values are overwritten for every dataset by the tie(...) assignments below
double bestScore = 0;
json bestHyperparameters;
for (const auto& dataset : datasets_names) {
if (!config.quiet)
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
auto combinations = grid.getGrid(dataset);
timer_dataset.start();
if (config.nested == 0)
// Single mode, loop nesting: dataset -> hyperparameters -> seed -> fold
tie(bestScore, bestHyperparameters) = processFileSingle(dataset, datasets, combinations);
else
// Nested mode, loop nesting: dataset -> seed -> fold -> hyperparameters -> nested fold
tie(bestScore, bestHyperparameters) = processFileNested(dataset, datasets, combinations);
if (!config.quiet) {
std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed
<< bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl;
}
// Per-dataset record stored under the dataset name
json result = {
{ "score", bestScore },
{ "hyperparameters", bestHyperparameters },
{ "date", get_date() + " " + get_time() },
{ "grid", grid.getInputGrid(dataset) },
{ "duration", timer_dataset.getDurationString() }
};
results[dataset] = result;
// Save the partial results after every dataset
save(results);
}
// Save the final results once all datasets have been processed
save(results);
std::cout << "***************** Ending " << grid_type << " Gridsearch *******************" << std::endl;
}
/**
 * Evaluate every hyperparameter combination on one dataset and return the best one.
 *
 * Each combination is scored with k-fold cross validation (stratified when
 * config.stratified is set) repeated for every seed in config.seeds; the score of a
 * combination is the mean of clf->score() over all (seed, fold) pairs.
 *
 * @param fileName     dataset name used to query `datasets` and the hyperparameters
 * @param datasets     container the tensors, states, features and class name are read from
 * @param combinations hyperparameter combinations to evaluate
 * @return {best mean score, combination that produced it}. A combination only replaces
 *         the current best when its score is strictly greater, so if none scores above
 *         0.0 the result is {0.0, default-constructed json}.
 */
pair<double, json> GridSearch::processFileSingle(std::string fileName, Datasets& datasets, vector<json>& combinations)
{
    // The dataset does not depend on the combination under test: load it once
    // instead of once per combination.
    auto [X, y] = datasets.getTensors(fileName);
    auto states = datasets.getStates(fileName);
    auto features = datasets.getFeatures(fileName);
    auto className = datasets.getClassName(fileName);
    // showProgressComb() takes an int; make the size_t -> int conversion explicit
    const int totalComb = static_cast<int>(combinations.size());
    int num = 0;
    double bestScore = 0.0;
    json bestHyperparameters;
    for (const auto& hyperparam_line : combinations) {
        if (!config.quiet)
            showProgressComb(++num, config.n_folds, totalComb, Colors::CYAN());
        auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
        double totalScore = 0.0;
        int numItems = 0;
        for (const auto& seed : config.seeds) {
            if (!config.quiet)
                std::cout << "(" << seed << ") doing Fold: " << flush;
            Fold* fold = nullptr;
            if (config.stratified)
                fold = new StratifiedKFold(config.n_folds, y, seed);
            else
                fold = new KFold(config.n_folds, y.size(0), seed);
            try {
                for (int nfold = 0; nfold < config.n_folds; nfold++) {
                    // Fresh classifier per fold, configured with the combination under test
                    auto clf = Models::instance()->create(config.model);
                    auto valid = clf->getValidHyperparameters();
                    hyperparameters.check(valid, fileName);
                    clf->setHyperparameters(hyperparameters.get(fileName));
                    auto [train, test] = fold->getFold(nfold);
                    auto train_t = torch::tensor(train);
                    auto test_t = torch::tensor(test);
                    auto X_train = X.index({ "...", train_t });
                    auto y_train = y.index({ train_t });
                    auto X_test = X.index({ "...", test_t });
                    auto y_test = y.index({ test_t });
                    // Train model
                    if (!config.quiet)
                        showProgressFold(nfold + 1, getColor(clf->getStatus()), "a");
                    clf->fit(X_train, y_train, features, className, states);
                    // Test model
                    if (!config.quiet)
                        showProgressFold(nfold + 1, getColor(clf->getStatus()), "b");
                    totalScore += clf->score(X_test, y_test);
                    numItems++;
                    if (!config.quiet)
                        std::cout << "\b\b\b, \b" << flush;
                }
            }
            catch (...) {
                // Do not leak the fold when model creation, validation, fit or score throws
                delete fold;
                throw;
            }
            delete fold;
        }
        // Mean score over every (seed, fold) pair; 0.0 when nothing was evaluated
        double score = numItems == 0 ? 0.0 : totalScore / numItems;
        if (score > bestScore) {
            bestScore = score;
            bestHyperparameters = hyperparam_line;
        }
    }
    return { bestScore, bestHyperparameters };
}
pair<double, json> GridSearch::processFileNested(std::string fileName, Datasets& datasets, vector<json>& combinations)
{ {
// Get dataset // Get dataset
auto [X, y] = datasets.getTensors(fileName); auto [X, y] = datasets.getTensors(fileName);
auto states = datasets.getStates(fileName); auto states = datasets.getStates(fileName);
auto features = datasets.getFeatures(fileName); auto features = datasets.getFeatures(fileName);
auto className = datasets.getClassName(fileName); auto className = datasets.getClassName(fileName);
double totalScore = 0.0; double bestScore = 0.0;
json bestHyperparameters;
int numItems = 0; int numItems = 0;
// for dataset // for seed // for fold // for hyperparameters // for nested fold
for (const auto& seed : config.seeds) { for (const auto& seed : config.seeds) {
if (!config.quiet)
std::cout << "(" << seed << ") doing Fold: " << flush;
Fold* fold; Fold* fold;
if (config.stratified) if (config.stratified)
fold = new StratifiedKFold(config.n_folds, y, seed); fold = new StratifiedKFold(config.n_folds, y, seed);
@ -82,10 +184,7 @@ namespace platform {
fold = new KFold(config.n_folds, y.size(0), seed); fold = new KFold(config.n_folds, y.size(0), seed);
double bestScore = 0.0; double bestScore = 0.0;
for (int nfold = 0; nfold < config.n_folds; nfold++) { for (int nfold = 0; nfold < config.n_folds; nfold++) {
auto clf = Models::instance()->create(config.model); // First level fold
auto valid = clf->getValidHyperparameters();
hyperparameters.check(valid, fileName);
clf->setHyperparameters(hyperparameters.get(fileName));
auto [train, test] = fold->getFold(nfold); auto [train, test] = fold->getFold(nfold);
auto train_t = torch::tensor(train); auto train_t = torch::tensor(train);
auto test_t = torch::tensor(test); auto test_t = torch::tensor(test);
@ -93,28 +192,50 @@ namespace platform {
auto y_train = y.index({ train_t }); auto y_train = y.index({ train_t });
auto X_test = X.index({ "...", test_t }); auto X_test = X.index({ "...", test_t });
auto y_test = y.index({ test_t }); auto y_test = y.index({ test_t });
// Train model for (const auto& hyperparam_line : combinations) {
if (!config.quiet) Fold* nested_fold;
showProgressFold(nfold + 1, getColor(clf->getStatus()), "a"); if (config.stratified)
clf->fit(X_train, y_train, features, className, states); nested_fold = new StratifiedKFold(config.nested, y_train, seed);
// Test model else
if (!config.quiet) nested_fold = new KFold(config.nested, y_train.size(0), seed);
showProgressFold(nfold + 1, getColor(clf->getStatus()), "b");
totalScore += clf->score(X_test, y_test); for (int n_nested_fold = 0; n_nested_fold < config.nested; n_nested_fold++) {
numItems++; // Nested level fold
if (!config.quiet) auto [train_nested, test_nested] = fold->getFold(n_nested_fold);
std::cout << "\b\b\b, \b" << flush; auto train_nested_t = torch::tensor(train_nested);
auto test_nested_t = torch::tensor(test_nested);
auto X_nexted_train = X_train.index({ "...", train_nested_t });
auto y_nested_train = y_train.index({ train_nested_t });
auto X_nested_test = X_train.index({ "...", test_nested_t });
auto y_nested_test = y_train.index({ test_nested_t });
// Build Classifier with selected hyperparameters
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
auto clf = Models::instance()->create(config.model);
auto valid = clf->getValidHyperparameters();
hyperparameters.check(valid, fileName);
clf->setHyperparameters(hyperparameters.get(fileName));
// Train model
if (!config.quiet)
showProgressFold(nfold + 1, getColor(clf->getStatus()), "a");
clf->fit(X_nexted_train, y_nested_train, features, className, states);
// Test model
if (!config.quiet)
showProgressFold(nfold + 1, getColor(clf->getStatus()), "b");
bestScore += clf->score(X_nested_test, y_nested_test);
}
delete nested_fold;
}
} }
delete fold; delete fold;
} }
return numItems == 0 ? 0.0 : totalScore / numItems; return { bestScore, bestHyperparameters };
} }
vector<std::string> GridSearch::processDatasets(Datasets& datasets) vector<std::string> GridSearch::processDatasets(Datasets& datasets)
{ {
// Load datasets // Load datasets
auto datasets_names = datasets.getNames(); auto datasets_names = datasets.getNames();
if (config.continue_from != "No") { if (config.continue_from != NO_CONTINUE()) {
// Continue previous execution: // Continue previous execution:
// remove datasets already processed // remove datasets already processed
if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) { if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) {
@ -139,7 +260,7 @@ namespace platform {
{ {
// Load previous results // Load previous results
json results; json results;
if (config.continue_from != "No") { if (config.continue_from != NO_CONTINUE()) {
if (!config.quiet) if (!config.quiet)
std::cout << "* Loading previous results" << std::endl; std::cout << "* Loading previous results" << std::endl;
try { try {
@ -157,100 +278,7 @@ namespace platform {
} }
return results; return results;
} }
void GridSearch::save(json& results)
void GridSearch::goSingle()
{
auto datasets = Datasets(config.discretize, Paths::datasets());
auto datasets_names = processDatasets(datasets);
json results = initializeResults();
std::cout << "***************** Starting Single Gridsearch *****************" << std::endl;
std::cout << "input file=" << Paths::grid_input(config.model) << std::endl;
auto grid = GridData(Paths::grid_input(config.model));
// Generate hyperparameters grid & run gridsearch
// Check each combination of hyperparameters for each dataset and each seed
for (const auto& dataset : datasets_names) {
auto totalComb = grid.getNumCombinations(dataset);
if (!config.quiet)
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
int num = 0;
double bestScore = 0.0;
json bestHyperparameters;
auto combinations = grid.getGrid(dataset);
for (const auto& hyperparam_line : combinations) {
if (!config.quiet)
showProgressComb(++num, totalComb, Colors::CYAN());
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
double score = processFileSingle(dataset, datasets, hyperparameters);
if (score > bestScore) {
bestScore = score;
bestHyperparameters = hyperparam_line;
}
}
if (!config.quiet) {
std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed
<< bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl;
}
json result = {
{ "score", bestScore },
{ "hyperparameters", bestHyperparameters },
{ "date", get_date() + " " + get_time() },
{ "grid", grid.getInputGrid(dataset) }
};
results[dataset] = result;
// Save partial results
save(results);
}
// Save final results
save(results);
std::cout << "***************** Ending Single Gridsearch *******************" << std::endl;
}
void GridSearch::goNested()
{
auto datasets = Datasets(config.discretize, Paths::datasets());
auto datasets_names = processDatasets(datasets);
json results = initializeResults();
std::cout << "***************** Starting Nested Gridsearch *****************" << std::endl;
std::cout << "input file=" << Paths::grid_input(config.model) << std::endl;
auto grid = GridData(Paths::grid_input(config.model));
// Generate hyperparameters grid & run gridsearch
// Check each combination of hyperparameters for each dataset and each seed
for (const auto& dataset : datasets_names) {
auto totalComb = grid.getNumCombinations(dataset);
if (!config.quiet)
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
int num = 0;
double bestScore = 0.0;
json bestHyperparameters;
auto combinations = grid.getGrid(dataset);
for (const auto& hyperparam_line : combinations) {
if (!config.quiet)
showProgressComb(++num, totalComb, Colors::CYAN());
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
double score = processFileSingle(dataset, datasets, hyperparameters);
if (score > bestScore) {
bestScore = score;
bestHyperparameters = hyperparam_line;
}
}
if (!config.quiet) {
std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed
<< bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl;
}
json result = {
{ "score", bestScore },
{ "hyperparameters", bestHyperparameters },
{ "date", get_date() + " " + get_time() },
{ "grid", grid.getInputGrid(dataset) }
};
results[dataset] = result;
// Save partial results
save(results);
}
// Save final results
save(results);
std::cout << "***************** Ending Nested Gridsearch *******************" << std::endl;
}
void GridSearch::save(json& results) const
{ {
std::ofstream file(Paths::grid_output(config.model)); std::ofstream file(Paths::grid_output(config.model));
json output = { json output = {
@ -262,7 +290,10 @@ namespace platform {
{ "seeds", config.seeds }, { "seeds", config.seeds },
{ "date", get_date() + " " + get_time()}, { "date", get_date() + " " + get_time()},
{ "nested", config.nested}, { "nested", config.nested},
{ "platform", config.platform },
{ "duration", timer.getDurationString(true)},
{ "results", results } { "results", results }
}; };
file << output.dump(4); file << output.dump(4);
} }

View File

@ -6,6 +6,7 @@
#include "Datasets.h" #include "Datasets.h"
#include "HyperParameters.h" #include "HyperParameters.h"
#include "GridData.h" #include "GridData.h"
#include "Timer.h"
namespace platform { namespace platform {
using json = nlohmann::json; using json = nlohmann::json;
@ -13,6 +14,7 @@ namespace platform {
std::string model; std::string model;
std::string score; std::string score;
std::string continue_from; std::string continue_from;
std::string platform;
bool quiet; bool quiet;
bool only; // used with continue_from to only compute that dataset bool only; // used with continue_from to only compute that dataset
bool discretize; bool discretize;
@ -24,16 +26,18 @@ namespace platform {
class GridSearch { class GridSearch {
public: public:
explicit GridSearch(struct ConfigGrid& config); explicit GridSearch(struct ConfigGrid& config);
void goSingle(); void go();
void goNested();
~GridSearch() = default; ~GridSearch() = default;
json getResults(); json getResults();
static inline std::string NO_CONTINUE() { return "NO_CONTINUE"; }
private: private:
void save(json& results) const; void save(json& results);
json initializeResults(); json initializeResults();
vector<std::string> processDatasets(Datasets& datasets); vector<std::string> processDatasets(Datasets& datasets);
double processFileSingle(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters); pair<double, json> processFileSingle(std::string fileName, Datasets& datasets, std::vector<json>& combinations);
pair<double, json> processFileNested(std::string fileName, Datasets& datasets, std::vector<json>& combinations);
struct ConfigGrid config; struct ConfigGrid config;
Timer timer; // used to measure the time of the whole process
}; };
} /* namespace platform */ } /* namespace platform */
#endif /* GRIDSEARCH_H */ #endif /* GRIDSEARCH_H */

View File

@ -20,9 +20,14 @@ namespace platform {
std::chrono::duration<double> time_span = std::chrono::duration_cast<std::chrono::duration<double >> (end - begin); std::chrono::duration<double> time_span = std::chrono::duration_cast<std::chrono::duration<double >> (end - begin);
return time_span.count(); return time_span.count();
} }
std::string getDurationString() double getLapse()
{ {
double duration = getDuration(); std::chrono::duration<double> time_span = std::chrono::duration_cast<std::chrono::duration<double >> (std::chrono::high_resolution_clock::now() - begin);
return time_span.count();
}
std::string getDurationString(bool lapse = false)
{
double duration = lapse ? getLapse() : getDuration();
double durationShow = duration > 3600 ? duration / 3600 : duration > 60 ? duration / 60 : duration; double durationShow = duration > 3600 ? duration / 3600 : duration > 60 ? duration / 60 : duration;
std::string durationUnit = duration > 3600 ? "h" : duration > 60 ? "m" : "s"; std::string durationUnit = duration > 3600 ? "h" : duration > 60 ? "m" : "s";
std::stringstream ss; std::stringstream ss;

View File

@ -33,7 +33,7 @@ void manageArguments(argparse::ArgumentParser& program)
program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true); program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true); program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true); program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
program.add_argument("--continue").help("Continue computing from that dataset").default_value("No"); program.add_argument("--continue").help("Continue computing from that dataset").default_value(platform::GridSearch::NO_CONTINUE());
program.add_argument("--only").help("Used with continue to compute that dataset only").default_value(false).implicit_value(true); program.add_argument("--only").help("Used with continue to compute that dataset only").default_value(false).implicit_value(true);
program.add_argument("--nested").help("Do a double/nested cross validation with n folds").default_value(0).scan<'i', int>(); program.add_argument("--nested").help("Do a double/nested cross validation with n folds").default_value(0).scan<'i', int>();
program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy"); program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy");
@ -95,7 +95,7 @@ void list_results(json& results, std::string& model)
{ {
std::cout << Colors::MAGENTA() << std::string(MAXL, '*') << std::endl; std::cout << Colors::MAGENTA() << std::string(MAXL, '*') << std::endl;
std::cout << headerLine("Listing computed hyperparameters for model " + model); std::cout << headerLine("Listing computed hyperparameters for model " + model);
std::cout << headerLine("Date & time: " + results["date"].get<std::string>()); std::cout << headerLine("Date & time: " + results["date"].get<std::string>() + " Duration: " + results["duration"].get<std::string>());
std::cout << headerLine("Score: " + results["score"].get<std::string>()); std::cout << headerLine("Score: " + results["score"].get<std::string>());
std::cout << headerLine( std::cout << headerLine(
"Random seeds: " + results["seeds"].dump() "Random seeds: " + results["seeds"].dump()
@ -118,9 +118,9 @@ void list_results(json& results, std::string& model)
} }
} }
std::cout << Colors::GREEN() << " # " << left << setw(spaces) << "Dataset" << " " << setw(19) << "Date" << " " std::cout << Colors::GREEN() << " # " << left << setw(spaces) << "Dataset" << " " << setw(19) << "Date" << " "
<< setw(8) << "Score" << " " << "Hyperparameters" << std::endl; << "Duration " << setw(8) << "Score" << " " << "Hyperparameters" << std::endl;
std::cout << "=== " << string(spaces, '=') << " " << string(19, '=') << " " << string(8, '=') << " " std::cout << "=== " << string(spaces, '=') << " " << string(19, '=') << " " << string(8, '=') << " "
<< string(hyperparameters_spaces, '=') << std::endl; << string(8, '=') << " " << string(hyperparameters_spaces, '=') << std::endl;
bool odd = true; bool odd = true;
int index = 0; int index = 0;
for (const auto& item : results["results"].items()) { for (const auto& item : results["results"].items()) {
@ -130,8 +130,8 @@ void list_results(json& results, std::string& model)
std::cout << color; std::cout << color;
std::cout << std::setw(3) << std::right << index++ << " "; std::cout << std::setw(3) << std::right << index++ << " ";
std::cout << left << setw(spaces) << key << " " << value["date"].get<string>() std::cout << left << setw(spaces) << key << " " << value["date"].get<string>()
<< " " << setw(8) << setprecision(6) << fixed << right << " " << setw(8) << value["duration"] << " " << setw(8) << setprecision(6)
<< value["score"].get<double>() << " " << value["hyperparameters"].dump() << std::endl; << fixed << right << value["score"].get<double>() << " " << value["hyperparameters"].dump() << std::endl;
odd = !odd; odd = !odd;
} }
std::cout << Colors::RESET() << std::endl; std::cout << Colors::RESET() << std::endl;
@ -159,12 +159,12 @@ int main(int argc, char** argv)
config.seeds = program.get<std::vector<int>>("seeds"); config.seeds = program.get<std::vector<int>>("seeds");
config.nested = program.get<int>("nested"); config.nested = program.get<int>("nested");
config.continue_from = program.get<std::string>("continue"); config.continue_from = program.get<std::string>("continue");
if (config.continue_from == "No" && config.only) { if (config.continue_from == platform::GridSearch::NO_CONTINUE() && config.only) {
throw std::runtime_error("Cannot use --only without --continue"); throw std::runtime_error("Cannot use --only without --continue");
} }
dump = program.get<bool>("dump"); dump = program.get<bool>("dump");
compute = program.get<bool>("compute"); compute = program.get<bool>("compute");
if (dump && (config.continue_from != "No" || config.only)) { if (dump && (config.continue_from != platform::GridSearch::NO_CONTINUE() || config.only)) {
throw std::runtime_error("Cannot use --dump with --continue or --only"); throw std::runtime_error("Cannot use --dump with --continue or --only");
} }
} }
@ -177,6 +177,7 @@ int main(int argc, char** argv)
* Begin Processing * Begin Processing
*/ */
auto env = platform::DotEnv(); auto env = platform::DotEnv();
config.platform = env.get("platform");
platform::Paths::createPath(platform::Paths::grid()); platform::Paths::createPath(platform::Paths::grid());
auto grid_search = platform::GridSearch(config); auto grid_search = platform::GridSearch(config);
platform::Timer timer; platform::Timer timer;
@ -185,10 +186,7 @@ int main(int argc, char** argv)
list_dump(config.model); list_dump(config.model);
} else { } else {
if (compute) { if (compute) {
if (config.nested == 0) grid_search.go();
grid_search.goSingle();
else
grid_search.goNested();
std::cout << "Process took " << timer.getDurationString() << std::endl; std::cout << "Process took " << timer.getDurationString() << std::endl;
} else { } else {
// List results // List results