refactor gridsearch to have only one go method

This commit is contained in:
Ricardo Montañana Gómez 2023-12-02 10:59:05 +01:00
parent 33cd32c639
commit 03e4437fea
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
5 changed files with 176 additions and 137 deletions

3
.vscode/launch.json vendored
View File

@ -46,7 +46,8 @@
"--discretize",
"--continue",
"glass",
"--only"
"--only",
"--compute"
],
"cwd": "${workspaceFolder}/../discretizbench",
},

View File

@ -38,10 +38,10 @@ namespace platform {
}
return json();
}
void showProgressComb(const int num, const int total, const std::string& color)
void showProgressComb(const int num, const int n_folds, const int total, const std::string& color)
{
int spaces = int(log(total) / log(10)) + 1;
int magic = 37 + 2 * spaces;
int magic = n_folds * 3 + 22 + 2 * spaces;
std::string prefix = num == 1 ? "" : string(magic, '\b') + string(magic + 1, ' ') + string(magic + 1, '\b');
std::cout << prefix << color << "(" << setw(spaces) << num << "/" << setw(spaces) << total << ") " << Colors::RESET() << flush;
}
@ -63,18 +63,120 @@ namespace platform {
return Colors::RESET();
}
}
double GridSearch::processFileSingle(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters)
void GridSearch::go()
{
timer.start();
auto grid_type = config.nested == 0 ? "Single" : "Nested";
auto datasets = Datasets(config.discretize, Paths::datasets());
auto datasets_names = processDatasets(datasets);
json results = initializeResults();
std::cout << "***************** Starting " << grid_type << " Gridsearch *****************" << std::endl;
std::cout << "input file=" << Paths::grid_input(config.model) << std::endl;
auto grid = GridData(Paths::grid_input(config.model));
Timer timer_dataset;
double bestScore = 0;
json bestHyperparameters;
for (const auto& dataset : datasets_names) {
if (!config.quiet)
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
auto combinations = grid.getGrid(dataset);
timer_dataset.start();
if (config.nested == 0)
// for dataset // for hyperparameters // for seed // for fold
tie(bestScore, bestHyperparameters) = processFileSingle(dataset, datasets, combinations);
else
// for dataset // for seed // for fold // for hyperparameters // for nested fold
tie(bestScore, bestHyperparameters) = processFileNested(dataset, datasets, combinations);
if (!config.quiet) {
std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed
<< bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl;
}
json result = {
{ "score", bestScore },
{ "hyperparameters", bestHyperparameters },
{ "date", get_date() + " " + get_time() },
{ "grid", grid.getInputGrid(dataset) },
{ "duration", timer_dataset.getDurationString() }
};
results[dataset] = result;
// Save partial results
save(results);
}
// Save final results
save(results);
std::cout << "***************** Ending " << grid_type << " Gridsearch *******************" << std::endl;
}
pair<double, json> GridSearch::processFileSingle(std::string fileName, Datasets& datasets, vector<json>& combinations)
{
int num = 0;
double bestScore = 0.0;
json bestHyperparameters;
auto totalComb = combinations.size();
for (const auto& hyperparam_line : combinations) {
if (!config.quiet)
showProgressComb(++num, config.n_folds, totalComb, Colors::CYAN());
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
// Get dataset
auto [X, y] = datasets.getTensors(fileName);
auto states = datasets.getStates(fileName);
auto features = datasets.getFeatures(fileName);
auto className = datasets.getClassName(fileName);
double totalScore = 0.0;
int numItems = 0;
for (const auto& seed : config.seeds) {
if (!config.quiet)
std::cout << "(" << seed << ") doing Fold: " << flush;
Fold* fold;
if (config.stratified)
fold = new StratifiedKFold(config.n_folds, y, seed);
else
fold = new KFold(config.n_folds, y.size(0), seed);
for (int nfold = 0; nfold < config.n_folds; nfold++) {
auto clf = Models::instance()->create(config.model);
auto valid = clf->getValidHyperparameters();
hyperparameters.check(valid, fileName);
clf->setHyperparameters(hyperparameters.get(fileName));
auto [train, test] = fold->getFold(nfold);
auto train_t = torch::tensor(train);
auto test_t = torch::tensor(test);
auto X_train = X.index({ "...", train_t });
auto y_train = y.index({ train_t });
auto X_test = X.index({ "...", test_t });
auto y_test = y.index({ test_t });
// Train model
if (!config.quiet)
showProgressFold(nfold + 1, getColor(clf->getStatus()), "a");
clf->fit(X_train, y_train, features, className, states);
// Test model
if (!config.quiet)
showProgressFold(nfold + 1, getColor(clf->getStatus()), "b");
totalScore += clf->score(X_test, y_test);
numItems++;
if (!config.quiet)
std::cout << "\b\b\b, \b" << flush;
}
delete fold;
}
double score = numItems == 0 ? 0.0 : totalScore / numItems;
if (score > bestScore) {
bestScore = score;
bestHyperparameters = hyperparam_line;
}
}
return { bestScore, bestHyperparameters };
}
pair<double, json> GridSearch::processFileNested(std::string fileName, Datasets& datasets, vector<json>& combinations)
{
// Get dataset
auto [X, y] = datasets.getTensors(fileName);
auto states = datasets.getStates(fileName);
auto features = datasets.getFeatures(fileName);
auto className = datasets.getClassName(fileName);
double totalScore = 0.0;
double bestScore = 0.0;
json bestHyperparameters;
int numItems = 0;
// for dataset // for seed // for fold // for hyperparameters // for nested fold
for (const auto& seed : config.seeds) {
if (!config.quiet)
std::cout << "(" << seed << ") doing Fold: " << flush;
Fold* fold;
if (config.stratified)
fold = new StratifiedKFold(config.n_folds, y, seed);
@ -82,10 +184,7 @@ namespace platform {
fold = new KFold(config.n_folds, y.size(0), seed);
double bestScore = 0.0;
for (int nfold = 0; nfold < config.n_folds; nfold++) {
auto clf = Models::instance()->create(config.model);
auto valid = clf->getValidHyperparameters();
hyperparameters.check(valid, fileName);
clf->setHyperparameters(hyperparameters.get(fileName));
// First level fold
auto [train, test] = fold->getFold(nfold);
auto train_t = torch::tensor(train);
auto test_t = torch::tensor(test);
@ -93,28 +192,50 @@ namespace platform {
auto y_train = y.index({ train_t });
auto X_test = X.index({ "...", test_t });
auto y_test = y.index({ test_t });
// Train model
if (!config.quiet)
showProgressFold(nfold + 1, getColor(clf->getStatus()), "a");
clf->fit(X_train, y_train, features, className, states);
// Test model
if (!config.quiet)
showProgressFold(nfold + 1, getColor(clf->getStatus()), "b");
totalScore += clf->score(X_test, y_test);
numItems++;
if (!config.quiet)
std::cout << "\b\b\b, \b" << flush;
for (const auto& hyperparam_line : combinations) {
Fold* nested_fold;
if (config.stratified)
nested_fold = new StratifiedKFold(config.nested, y_train, seed);
else
nested_fold = new KFold(config.nested, y_train.size(0), seed);
for (int n_nested_fold = 0; n_nested_fold < config.nested; n_nested_fold++) {
// Nested level fold
auto [train_nested, test_nested] = fold->getFold(n_nested_fold);
auto train_nested_t = torch::tensor(train_nested);
auto test_nested_t = torch::tensor(test_nested);
auto X_nexted_train = X_train.index({ "...", train_nested_t });
auto y_nested_train = y_train.index({ train_nested_t });
auto X_nested_test = X_train.index({ "...", test_nested_t });
auto y_nested_test = y_train.index({ test_nested_t });
// Build Classifier with selected hyperparameters
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
auto clf = Models::instance()->create(config.model);
auto valid = clf->getValidHyperparameters();
hyperparameters.check(valid, fileName);
clf->setHyperparameters(hyperparameters.get(fileName));
// Train model
if (!config.quiet)
showProgressFold(nfold + 1, getColor(clf->getStatus()), "a");
clf->fit(X_nexted_train, y_nested_train, features, className, states);
// Test model
if (!config.quiet)
showProgressFold(nfold + 1, getColor(clf->getStatus()), "b");
bestScore += clf->score(X_nested_test, y_nested_test);
}
delete nested_fold;
}
}
delete fold;
}
return numItems == 0 ? 0.0 : totalScore / numItems;
return { bestScore, bestHyperparameters };
}
vector<std::string> GridSearch::processDatasets(Datasets& datasets)
{
// Load datasets
auto datasets_names = datasets.getNames();
if (config.continue_from != "No") {
if (config.continue_from != NO_CONTINUE()) {
// Continue previous execution:
// remove datasets already processed
if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) {
@ -139,7 +260,7 @@ namespace platform {
{
// Load previous results
json results;
if (config.continue_from != "No") {
if (config.continue_from != NO_CONTINUE()) {
if (!config.quiet)
std::cout << "* Loading previous results" << std::endl;
try {
@ -157,100 +278,7 @@ namespace platform {
}
return results;
}
void GridSearch::goSingle()
{
auto datasets = Datasets(config.discretize, Paths::datasets());
auto datasets_names = processDatasets(datasets);
json results = initializeResults();
std::cout << "***************** Starting Single Gridsearch *****************" << std::endl;
std::cout << "input file=" << Paths::grid_input(config.model) << std::endl;
auto grid = GridData(Paths::grid_input(config.model));
// Generate hyperparameters grid & run gridsearch
// Check each combination of hyperparameters for each dataset and each seed
for (const auto& dataset : datasets_names) {
auto totalComb = grid.getNumCombinations(dataset);
if (!config.quiet)
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
int num = 0;
double bestScore = 0.0;
json bestHyperparameters;
auto combinations = grid.getGrid(dataset);
for (const auto& hyperparam_line : combinations) {
if (!config.quiet)
showProgressComb(++num, totalComb, Colors::CYAN());
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
double score = processFileSingle(dataset, datasets, hyperparameters);
if (score > bestScore) {
bestScore = score;
bestHyperparameters = hyperparam_line;
}
}
if (!config.quiet) {
std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed
<< bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl;
}
json result = {
{ "score", bestScore },
{ "hyperparameters", bestHyperparameters },
{ "date", get_date() + " " + get_time() },
{ "grid", grid.getInputGrid(dataset) }
};
results[dataset] = result;
// Save partial results
save(results);
}
// Save final results
save(results);
std::cout << "***************** Ending Single Gridsearch *******************" << std::endl;
}
void GridSearch::goNested()
{
auto datasets = Datasets(config.discretize, Paths::datasets());
auto datasets_names = processDatasets(datasets);
json results = initializeResults();
std::cout << "***************** Starting Nested Gridsearch *****************" << std::endl;
std::cout << "input file=" << Paths::grid_input(config.model) << std::endl;
auto grid = GridData(Paths::grid_input(config.model));
// Generate hyperparameters grid & run gridsearch
// Check each combination of hyperparameters for each dataset and each seed
for (const auto& dataset : datasets_names) {
auto totalComb = grid.getNumCombinations(dataset);
if (!config.quiet)
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
int num = 0;
double bestScore = 0.0;
json bestHyperparameters;
auto combinations = grid.getGrid(dataset);
for (const auto& hyperparam_line : combinations) {
if (!config.quiet)
showProgressComb(++num, totalComb, Colors::CYAN());
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
double score = processFileSingle(dataset, datasets, hyperparameters);
if (score > bestScore) {
bestScore = score;
bestHyperparameters = hyperparam_line;
}
}
if (!config.quiet) {
std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed
<< bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl;
}
json result = {
{ "score", bestScore },
{ "hyperparameters", bestHyperparameters },
{ "date", get_date() + " " + get_time() },
{ "grid", grid.getInputGrid(dataset) }
};
results[dataset] = result;
// Save partial results
save(results);
}
// Save final results
save(results);
std::cout << "***************** Ending Nested Gridsearch *******************" << std::endl;
}
void GridSearch::save(json& results) const
void GridSearch::save(json& results)
{
std::ofstream file(Paths::grid_output(config.model));
json output = {
@ -262,7 +290,10 @@ namespace platform {
{ "seeds", config.seeds },
{ "date", get_date() + " " + get_time()},
{ "nested", config.nested},
{ "platform", config.platform },
{ "duration", timer.getDurationString(true)},
{ "results", results }
};
file << output.dump(4);
}

View File

@ -6,6 +6,7 @@
#include "Datasets.h"
#include "HyperParameters.h"
#include "GridData.h"
#include "Timer.h"
namespace platform {
using json = nlohmann::json;
@ -13,6 +14,7 @@ namespace platform {
std::string model;
std::string score;
std::string continue_from;
std::string platform;
bool quiet;
bool only; // used with continue_from to only compute that dataset
bool discretize;
@ -24,16 +26,18 @@ namespace platform {
class GridSearch {
public:
explicit GridSearch(struct ConfigGrid& config);
void goSingle();
void goNested();
void go();
~GridSearch() = default;
json getResults();
static inline std::string NO_CONTINUE() { return "NO_CONTINUE"; }
private:
void save(json& results) const;
void save(json& results);
json initializeResults();
vector<std::string> processDatasets(Datasets& datasets);
double processFileSingle(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters);
pair<double, json> processFileSingle(std::string fileName, Datasets& datasets, std::vector<json>& combinations);
pair<double, json> processFileNested(std::string fileName, Datasets& datasets, std::vector<json>& combinations);
struct ConfigGrid config;
Timer timer; // used to measure the time of the whole process
};
} /* namespace platform */
#endif /* GRIDSEARCH_H */

View File

@ -20,9 +20,14 @@ namespace platform {
std::chrono::duration<double> time_span = std::chrono::duration_cast<std::chrono::duration<double >> (end - begin);
return time_span.count();
}
std::string getDurationString()
double getLapse()
{
double duration = getDuration();
std::chrono::duration<double> time_span = std::chrono::duration_cast<std::chrono::duration<double >> (std::chrono::high_resolution_clock::now() - begin);
return time_span.count();
}
std::string getDurationString(bool lapse = false)
{
double duration = lapse ? getLapse() : getDuration();
double durationShow = duration > 3600 ? duration / 3600 : duration > 60 ? duration / 60 : duration;
std::string durationUnit = duration > 3600 ? "h" : duration > 60 ? "m" : "s";
std::stringstream ss;

View File

@ -33,7 +33,7 @@ void manageArguments(argparse::ArgumentParser& program)
program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
program.add_argument("--continue").help("Continue computing from that dataset").default_value("No");
program.add_argument("--continue").help("Continue computing from that dataset").default_value(platform::GridSearch::NO_CONTINUE());
program.add_argument("--only").help("Used with continue to compute that dataset only").default_value(false).implicit_value(true);
program.add_argument("--nested").help("Do a double/nested cross validation with n folds").default_value(0).scan<'i', int>();
program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy");
@ -95,7 +95,7 @@ void list_results(json& results, std::string& model)
{
std::cout << Colors::MAGENTA() << std::string(MAXL, '*') << std::endl;
std::cout << headerLine("Listing computed hyperparameters for model " + model);
std::cout << headerLine("Date & time: " + results["date"].get<std::string>());
std::cout << headerLine("Date & time: " + results["date"].get<std::string>() + " Duration: " + results["duration"].get<std::string>());
std::cout << headerLine("Score: " + results["score"].get<std::string>());
std::cout << headerLine(
"Random seeds: " + results["seeds"].dump()
@ -118,9 +118,9 @@ void list_results(json& results, std::string& model)
}
}
std::cout << Colors::GREEN() << " # " << left << setw(spaces) << "Dataset" << " " << setw(19) << "Date" << " "
<< setw(8) << "Score" << " " << "Hyperparameters" << std::endl;
<< "Duration " << setw(8) << "Score" << " " << "Hyperparameters" << std::endl;
std::cout << "=== " << string(spaces, '=') << " " << string(19, '=') << " " << string(8, '=') << " "
<< string(hyperparameters_spaces, '=') << std::endl;
<< string(8, '=') << " " << string(hyperparameters_spaces, '=') << std::endl;
bool odd = true;
int index = 0;
for (const auto& item : results["results"].items()) {
@ -130,8 +130,8 @@ void list_results(json& results, std::string& model)
std::cout << color;
std::cout << std::setw(3) << std::right << index++ << " ";
std::cout << left << setw(spaces) << key << " " << value["date"].get<string>()
<< " " << setw(8) << setprecision(6) << fixed << right
<< value["score"].get<double>() << " " << value["hyperparameters"].dump() << std::endl;
<< " " << setw(8) << value["duration"] << " " << setw(8) << setprecision(6)
<< fixed << right << value["score"].get<double>() << " " << value["hyperparameters"].dump() << std::endl;
odd = !odd;
}
std::cout << Colors::RESET() << std::endl;
@ -159,12 +159,12 @@ int main(int argc, char** argv)
config.seeds = program.get<std::vector<int>>("seeds");
config.nested = program.get<int>("nested");
config.continue_from = program.get<std::string>("continue");
if (config.continue_from == "No" && config.only) {
if (config.continue_from == platform::GridSearch::NO_CONTINUE() && config.only) {
throw std::runtime_error("Cannot use --only without --continue");
}
dump = program.get<bool>("dump");
compute = program.get<bool>("compute");
if (dump && (config.continue_from != "No" || config.only)) {
if (dump && (config.continue_from != platform::GridSearch::NO_CONTINUE() || config.only)) {
throw std::runtime_error("Cannot use --dump with --continue or --only");
}
}
@ -177,6 +177,7 @@ int main(int argc, char** argv)
* Begin Processing
*/
auto env = platform::DotEnv();
config.platform = env.get("platform");
platform::Paths::createPath(platform::Paths::grid());
auto grid_search = platform::GridSearch(config);
platform::Timer timer;
@ -185,10 +186,7 @@ int main(int argc, char** argv)
list_dump(config.model);
} else {
if (compute) {
if (config.nested == 0)
grid_search.goSingle();
else
grid_search.goNested();
grid_search.go();
std::cout << "Process took " << timer.getDurationString() << std::endl;
} else {
// List results