refactor gridsearch to have only one go method
This commit is contained in:
parent
33cd32c639
commit
03e4437fea
3
.vscode/launch.json
vendored
3
.vscode/launch.json
vendored
@ -46,7 +46,8 @@
|
|||||||
"--discretize",
|
"--discretize",
|
||||||
"--continue",
|
"--continue",
|
||||||
"glass",
|
"glass",
|
||||||
"--only"
|
"--only",
|
||||||
|
"--compute"
|
||||||
],
|
],
|
||||||
"cwd": "${workspaceFolder}/../discretizbench",
|
"cwd": "${workspaceFolder}/../discretizbench",
|
||||||
},
|
},
|
||||||
|
@ -38,10 +38,10 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
return json();
|
return json();
|
||||||
}
|
}
|
||||||
void showProgressComb(const int num, const int total, const std::string& color)
|
void showProgressComb(const int num, const int n_folds, const int total, const std::string& color)
|
||||||
{
|
{
|
||||||
int spaces = int(log(total) / log(10)) + 1;
|
int spaces = int(log(total) / log(10)) + 1;
|
||||||
int magic = 37 + 2 * spaces;
|
int magic = n_folds * 3 + 22 + 2 * spaces;
|
||||||
std::string prefix = num == 1 ? "" : string(magic, '\b') + string(magic + 1, ' ') + string(magic + 1, '\b');
|
std::string prefix = num == 1 ? "" : string(magic, '\b') + string(magic + 1, ' ') + string(magic + 1, '\b');
|
||||||
std::cout << prefix << color << "(" << setw(spaces) << num << "/" << setw(spaces) << total << ") " << Colors::RESET() << flush;
|
std::cout << prefix << color << "(" << setw(spaces) << num << "/" << setw(spaces) << total << ") " << Colors::RESET() << flush;
|
||||||
}
|
}
|
||||||
@ -63,18 +63,120 @@ namespace platform {
|
|||||||
return Colors::RESET();
|
return Colors::RESET();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
double GridSearch::processFileSingle(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters)
|
void GridSearch::go()
|
||||||
|
{
|
||||||
|
timer.start();
|
||||||
|
auto grid_type = config.nested == 0 ? "Single" : "Nested";
|
||||||
|
auto datasets = Datasets(config.discretize, Paths::datasets());
|
||||||
|
auto datasets_names = processDatasets(datasets);
|
||||||
|
json results = initializeResults();
|
||||||
|
std::cout << "***************** Starting " << grid_type << " Gridsearch *****************" << std::endl;
|
||||||
|
std::cout << "input file=" << Paths::grid_input(config.model) << std::endl;
|
||||||
|
auto grid = GridData(Paths::grid_input(config.model));
|
||||||
|
Timer timer_dataset;
|
||||||
|
double bestScore = 0;
|
||||||
|
json bestHyperparameters;
|
||||||
|
for (const auto& dataset : datasets_names) {
|
||||||
|
if (!config.quiet)
|
||||||
|
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
|
||||||
|
auto combinations = grid.getGrid(dataset);
|
||||||
|
timer_dataset.start();
|
||||||
|
if (config.nested == 0)
|
||||||
|
// for dataset // for hyperparameters // for seed // for fold
|
||||||
|
tie(bestScore, bestHyperparameters) = processFileSingle(dataset, datasets, combinations);
|
||||||
|
else
|
||||||
|
// for dataset // for seed // for fold // for hyperparameters // for nested fold
|
||||||
|
tie(bestScore, bestHyperparameters) = processFileNested(dataset, datasets, combinations);
|
||||||
|
if (!config.quiet) {
|
||||||
|
std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed
|
||||||
|
<< bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl;
|
||||||
|
}
|
||||||
|
json result = {
|
||||||
|
{ "score", bestScore },
|
||||||
|
{ "hyperparameters", bestHyperparameters },
|
||||||
|
{ "date", get_date() + " " + get_time() },
|
||||||
|
{ "grid", grid.getInputGrid(dataset) },
|
||||||
|
{ "duration", timer_dataset.getDurationString() }
|
||||||
|
};
|
||||||
|
results[dataset] = result;
|
||||||
|
// Save partial results
|
||||||
|
save(results);
|
||||||
|
}
|
||||||
|
// Save final results
|
||||||
|
save(results);
|
||||||
|
std::cout << "***************** Ending " << grid_type << " Gridsearch *******************" << std::endl;
|
||||||
|
}
|
||||||
|
pair<double, json> GridSearch::processFileSingle(std::string fileName, Datasets& datasets, vector<json>& combinations)
|
||||||
|
{
|
||||||
|
int num = 0;
|
||||||
|
double bestScore = 0.0;
|
||||||
|
json bestHyperparameters;
|
||||||
|
auto totalComb = combinations.size();
|
||||||
|
for (const auto& hyperparam_line : combinations) {
|
||||||
|
if (!config.quiet)
|
||||||
|
showProgressComb(++num, config.n_folds, totalComb, Colors::CYAN());
|
||||||
|
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
|
||||||
|
// Get dataset
|
||||||
|
auto [X, y] = datasets.getTensors(fileName);
|
||||||
|
auto states = datasets.getStates(fileName);
|
||||||
|
auto features = datasets.getFeatures(fileName);
|
||||||
|
auto className = datasets.getClassName(fileName);
|
||||||
|
double totalScore = 0.0;
|
||||||
|
int numItems = 0;
|
||||||
|
for (const auto& seed : config.seeds) {
|
||||||
|
if (!config.quiet)
|
||||||
|
std::cout << "(" << seed << ") doing Fold: " << flush;
|
||||||
|
Fold* fold;
|
||||||
|
if (config.stratified)
|
||||||
|
fold = new StratifiedKFold(config.n_folds, y, seed);
|
||||||
|
else
|
||||||
|
fold = new KFold(config.n_folds, y.size(0), seed);
|
||||||
|
for (int nfold = 0; nfold < config.n_folds; nfold++) {
|
||||||
|
auto clf = Models::instance()->create(config.model);
|
||||||
|
auto valid = clf->getValidHyperparameters();
|
||||||
|
hyperparameters.check(valid, fileName);
|
||||||
|
clf->setHyperparameters(hyperparameters.get(fileName));
|
||||||
|
auto [train, test] = fold->getFold(nfold);
|
||||||
|
auto train_t = torch::tensor(train);
|
||||||
|
auto test_t = torch::tensor(test);
|
||||||
|
auto X_train = X.index({ "...", train_t });
|
||||||
|
auto y_train = y.index({ train_t });
|
||||||
|
auto X_test = X.index({ "...", test_t });
|
||||||
|
auto y_test = y.index({ test_t });
|
||||||
|
// Train model
|
||||||
|
if (!config.quiet)
|
||||||
|
showProgressFold(nfold + 1, getColor(clf->getStatus()), "a");
|
||||||
|
clf->fit(X_train, y_train, features, className, states);
|
||||||
|
// Test model
|
||||||
|
if (!config.quiet)
|
||||||
|
showProgressFold(nfold + 1, getColor(clf->getStatus()), "b");
|
||||||
|
totalScore += clf->score(X_test, y_test);
|
||||||
|
numItems++;
|
||||||
|
if (!config.quiet)
|
||||||
|
std::cout << "\b\b\b, \b" << flush;
|
||||||
|
}
|
||||||
|
delete fold;
|
||||||
|
}
|
||||||
|
double score = numItems == 0 ? 0.0 : totalScore / numItems;
|
||||||
|
if (score > bestScore) {
|
||||||
|
bestScore = score;
|
||||||
|
bestHyperparameters = hyperparam_line;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return { bestScore, bestHyperparameters };
|
||||||
|
}
|
||||||
|
pair<double, json> GridSearch::processFileNested(std::string fileName, Datasets& datasets, vector<json>& combinations)
|
||||||
{
|
{
|
||||||
// Get dataset
|
// Get dataset
|
||||||
auto [X, y] = datasets.getTensors(fileName);
|
auto [X, y] = datasets.getTensors(fileName);
|
||||||
auto states = datasets.getStates(fileName);
|
auto states = datasets.getStates(fileName);
|
||||||
auto features = datasets.getFeatures(fileName);
|
auto features = datasets.getFeatures(fileName);
|
||||||
auto className = datasets.getClassName(fileName);
|
auto className = datasets.getClassName(fileName);
|
||||||
double totalScore = 0.0;
|
double bestScore = 0.0;
|
||||||
|
json bestHyperparameters;
|
||||||
int numItems = 0;
|
int numItems = 0;
|
||||||
|
// for dataset // for seed // for fold // for hyperparameters // for nested fold
|
||||||
for (const auto& seed : config.seeds) {
|
for (const auto& seed : config.seeds) {
|
||||||
if (!config.quiet)
|
|
||||||
std::cout << "(" << seed << ") doing Fold: " << flush;
|
|
||||||
Fold* fold;
|
Fold* fold;
|
||||||
if (config.stratified)
|
if (config.stratified)
|
||||||
fold = new StratifiedKFold(config.n_folds, y, seed);
|
fold = new StratifiedKFold(config.n_folds, y, seed);
|
||||||
@ -82,10 +184,7 @@ namespace platform {
|
|||||||
fold = new KFold(config.n_folds, y.size(0), seed);
|
fold = new KFold(config.n_folds, y.size(0), seed);
|
||||||
double bestScore = 0.0;
|
double bestScore = 0.0;
|
||||||
for (int nfold = 0; nfold < config.n_folds; nfold++) {
|
for (int nfold = 0; nfold < config.n_folds; nfold++) {
|
||||||
auto clf = Models::instance()->create(config.model);
|
// First level fold
|
||||||
auto valid = clf->getValidHyperparameters();
|
|
||||||
hyperparameters.check(valid, fileName);
|
|
||||||
clf->setHyperparameters(hyperparameters.get(fileName));
|
|
||||||
auto [train, test] = fold->getFold(nfold);
|
auto [train, test] = fold->getFold(nfold);
|
||||||
auto train_t = torch::tensor(train);
|
auto train_t = torch::tensor(train);
|
||||||
auto test_t = torch::tensor(test);
|
auto test_t = torch::tensor(test);
|
||||||
@ -93,28 +192,50 @@ namespace platform {
|
|||||||
auto y_train = y.index({ train_t });
|
auto y_train = y.index({ train_t });
|
||||||
auto X_test = X.index({ "...", test_t });
|
auto X_test = X.index({ "...", test_t });
|
||||||
auto y_test = y.index({ test_t });
|
auto y_test = y.index({ test_t });
|
||||||
// Train model
|
for (const auto& hyperparam_line : combinations) {
|
||||||
if (!config.quiet)
|
Fold* nested_fold;
|
||||||
showProgressFold(nfold + 1, getColor(clf->getStatus()), "a");
|
if (config.stratified)
|
||||||
clf->fit(X_train, y_train, features, className, states);
|
nested_fold = new StratifiedKFold(config.nested, y_train, seed);
|
||||||
// Test model
|
else
|
||||||
if (!config.quiet)
|
nested_fold = new KFold(config.nested, y_train.size(0), seed);
|
||||||
showProgressFold(nfold + 1, getColor(clf->getStatus()), "b");
|
|
||||||
totalScore += clf->score(X_test, y_test);
|
for (int n_nested_fold = 0; n_nested_fold < config.nested; n_nested_fold++) {
|
||||||
numItems++;
|
// Nested level fold
|
||||||
if (!config.quiet)
|
auto [train_nested, test_nested] = fold->getFold(n_nested_fold);
|
||||||
std::cout << "\b\b\b, \b" << flush;
|
auto train_nested_t = torch::tensor(train_nested);
|
||||||
|
auto test_nested_t = torch::tensor(test_nested);
|
||||||
|
auto X_nexted_train = X_train.index({ "...", train_nested_t });
|
||||||
|
auto y_nested_train = y_train.index({ train_nested_t });
|
||||||
|
auto X_nested_test = X_train.index({ "...", test_nested_t });
|
||||||
|
auto y_nested_test = y_train.index({ test_nested_t });
|
||||||
|
// Build Classifier with selected hyperparameters
|
||||||
|
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
|
||||||
|
auto clf = Models::instance()->create(config.model);
|
||||||
|
auto valid = clf->getValidHyperparameters();
|
||||||
|
hyperparameters.check(valid, fileName);
|
||||||
|
clf->setHyperparameters(hyperparameters.get(fileName));
|
||||||
|
// Train model
|
||||||
|
if (!config.quiet)
|
||||||
|
showProgressFold(nfold + 1, getColor(clf->getStatus()), "a");
|
||||||
|
clf->fit(X_nexted_train, y_nested_train, features, className, states);
|
||||||
|
// Test model
|
||||||
|
if (!config.quiet)
|
||||||
|
showProgressFold(nfold + 1, getColor(clf->getStatus()), "b");
|
||||||
|
bestScore += clf->score(X_nested_test, y_nested_test);
|
||||||
|
}
|
||||||
|
delete nested_fold;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
delete fold;
|
delete fold;
|
||||||
}
|
}
|
||||||
return numItems == 0 ? 0.0 : totalScore / numItems;
|
return { bestScore, bestHyperparameters };
|
||||||
}
|
}
|
||||||
vector<std::string> GridSearch::processDatasets(Datasets& datasets)
|
vector<std::string> GridSearch::processDatasets(Datasets& datasets)
|
||||||
{
|
{
|
||||||
// Load datasets
|
// Load datasets
|
||||||
|
|
||||||
auto datasets_names = datasets.getNames();
|
auto datasets_names = datasets.getNames();
|
||||||
if (config.continue_from != "No") {
|
if (config.continue_from != NO_CONTINUE()) {
|
||||||
// Continue previous execution:
|
// Continue previous execution:
|
||||||
// remove datasets already processed
|
// remove datasets already processed
|
||||||
if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) {
|
if (std::find(datasets_names.begin(), datasets_names.end(), config.continue_from) == datasets_names.end()) {
|
||||||
@ -139,7 +260,7 @@ namespace platform {
|
|||||||
{
|
{
|
||||||
// Load previous results
|
// Load previous results
|
||||||
json results;
|
json results;
|
||||||
if (config.continue_from != "No") {
|
if (config.continue_from != NO_CONTINUE()) {
|
||||||
if (!config.quiet)
|
if (!config.quiet)
|
||||||
std::cout << "* Loading previous results" << std::endl;
|
std::cout << "* Loading previous results" << std::endl;
|
||||||
try {
|
try {
|
||||||
@ -157,100 +278,7 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
void GridSearch::save(json& results)
|
||||||
void GridSearch::goSingle()
|
|
||||||
{
|
|
||||||
auto datasets = Datasets(config.discretize, Paths::datasets());
|
|
||||||
auto datasets_names = processDatasets(datasets);
|
|
||||||
json results = initializeResults();
|
|
||||||
std::cout << "***************** Starting Single Gridsearch *****************" << std::endl;
|
|
||||||
std::cout << "input file=" << Paths::grid_input(config.model) << std::endl;
|
|
||||||
auto grid = GridData(Paths::grid_input(config.model));
|
|
||||||
// Generate hyperparameters grid & run gridsearch
|
|
||||||
// Check each combination of hyperparameters for each dataset and each seed
|
|
||||||
for (const auto& dataset : datasets_names) {
|
|
||||||
auto totalComb = grid.getNumCombinations(dataset);
|
|
||||||
if (!config.quiet)
|
|
||||||
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
|
|
||||||
int num = 0;
|
|
||||||
double bestScore = 0.0;
|
|
||||||
json bestHyperparameters;
|
|
||||||
auto combinations = grid.getGrid(dataset);
|
|
||||||
for (const auto& hyperparam_line : combinations) {
|
|
||||||
if (!config.quiet)
|
|
||||||
showProgressComb(++num, totalComb, Colors::CYAN());
|
|
||||||
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
|
|
||||||
double score = processFileSingle(dataset, datasets, hyperparameters);
|
|
||||||
if (score > bestScore) {
|
|
||||||
bestScore = score;
|
|
||||||
bestHyperparameters = hyperparam_line;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!config.quiet) {
|
|
||||||
std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed
|
|
||||||
<< bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl;
|
|
||||||
}
|
|
||||||
json result = {
|
|
||||||
{ "score", bestScore },
|
|
||||||
{ "hyperparameters", bestHyperparameters },
|
|
||||||
{ "date", get_date() + " " + get_time() },
|
|
||||||
{ "grid", grid.getInputGrid(dataset) }
|
|
||||||
};
|
|
||||||
results[dataset] = result;
|
|
||||||
// Save partial results
|
|
||||||
save(results);
|
|
||||||
}
|
|
||||||
// Save final results
|
|
||||||
save(results);
|
|
||||||
std::cout << "***************** Ending Single Gridsearch *******************" << std::endl;
|
|
||||||
}
|
|
||||||
void GridSearch::goNested()
|
|
||||||
{
|
|
||||||
auto datasets = Datasets(config.discretize, Paths::datasets());
|
|
||||||
auto datasets_names = processDatasets(datasets);
|
|
||||||
json results = initializeResults();
|
|
||||||
std::cout << "***************** Starting Nested Gridsearch *****************" << std::endl;
|
|
||||||
std::cout << "input file=" << Paths::grid_input(config.model) << std::endl;
|
|
||||||
auto grid = GridData(Paths::grid_input(config.model));
|
|
||||||
// Generate hyperparameters grid & run gridsearch
|
|
||||||
// Check each combination of hyperparameters for each dataset and each seed
|
|
||||||
for (const auto& dataset : datasets_names) {
|
|
||||||
auto totalComb = grid.getNumCombinations(dataset);
|
|
||||||
if (!config.quiet)
|
|
||||||
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
|
|
||||||
int num = 0;
|
|
||||||
double bestScore = 0.0;
|
|
||||||
json bestHyperparameters;
|
|
||||||
auto combinations = grid.getGrid(dataset);
|
|
||||||
for (const auto& hyperparam_line : combinations) {
|
|
||||||
if (!config.quiet)
|
|
||||||
showProgressComb(++num, totalComb, Colors::CYAN());
|
|
||||||
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
|
|
||||||
double score = processFileSingle(dataset, datasets, hyperparameters);
|
|
||||||
if (score > bestScore) {
|
|
||||||
bestScore = score;
|
|
||||||
bestHyperparameters = hyperparam_line;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!config.quiet) {
|
|
||||||
std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed
|
|
||||||
<< bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl;
|
|
||||||
}
|
|
||||||
json result = {
|
|
||||||
{ "score", bestScore },
|
|
||||||
{ "hyperparameters", bestHyperparameters },
|
|
||||||
{ "date", get_date() + " " + get_time() },
|
|
||||||
{ "grid", grid.getInputGrid(dataset) }
|
|
||||||
};
|
|
||||||
results[dataset] = result;
|
|
||||||
// Save partial results
|
|
||||||
save(results);
|
|
||||||
}
|
|
||||||
// Save final results
|
|
||||||
save(results);
|
|
||||||
std::cout << "***************** Ending Nested Gridsearch *******************" << std::endl;
|
|
||||||
}
|
|
||||||
void GridSearch::save(json& results) const
|
|
||||||
{
|
{
|
||||||
std::ofstream file(Paths::grid_output(config.model));
|
std::ofstream file(Paths::grid_output(config.model));
|
||||||
json output = {
|
json output = {
|
||||||
@ -262,7 +290,10 @@ namespace platform {
|
|||||||
{ "seeds", config.seeds },
|
{ "seeds", config.seeds },
|
||||||
{ "date", get_date() + " " + get_time()},
|
{ "date", get_date() + " " + get_time()},
|
||||||
{ "nested", config.nested},
|
{ "nested", config.nested},
|
||||||
|
{ "platform", config.platform },
|
||||||
|
{ "duration", timer.getDurationString(true)},
|
||||||
{ "results", results }
|
{ "results", results }
|
||||||
|
|
||||||
};
|
};
|
||||||
file << output.dump(4);
|
file << output.dump(4);
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
#include "Datasets.h"
|
#include "Datasets.h"
|
||||||
#include "HyperParameters.h"
|
#include "HyperParameters.h"
|
||||||
#include "GridData.h"
|
#include "GridData.h"
|
||||||
|
#include "Timer.h"
|
||||||
|
|
||||||
namespace platform {
|
namespace platform {
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
@ -13,6 +14,7 @@ namespace platform {
|
|||||||
std::string model;
|
std::string model;
|
||||||
std::string score;
|
std::string score;
|
||||||
std::string continue_from;
|
std::string continue_from;
|
||||||
|
std::string platform;
|
||||||
bool quiet;
|
bool quiet;
|
||||||
bool only; // used with continue_from to only compute that dataset
|
bool only; // used with continue_from to only compute that dataset
|
||||||
bool discretize;
|
bool discretize;
|
||||||
@ -24,16 +26,18 @@ namespace platform {
|
|||||||
class GridSearch {
|
class GridSearch {
|
||||||
public:
|
public:
|
||||||
explicit GridSearch(struct ConfigGrid& config);
|
explicit GridSearch(struct ConfigGrid& config);
|
||||||
void goSingle();
|
void go();
|
||||||
void goNested();
|
|
||||||
~GridSearch() = default;
|
~GridSearch() = default;
|
||||||
json getResults();
|
json getResults();
|
||||||
|
static inline std::string NO_CONTINUE() { return "NO_CONTINUE"; }
|
||||||
private:
|
private:
|
||||||
void save(json& results) const;
|
void save(json& results);
|
||||||
json initializeResults();
|
json initializeResults();
|
||||||
vector<std::string> processDatasets(Datasets& datasets);
|
vector<std::string> processDatasets(Datasets& datasets);
|
||||||
double processFileSingle(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters);
|
pair<double, json> processFileSingle(std::string fileName, Datasets& datasets, std::vector<json>& combinations);
|
||||||
|
pair<double, json> processFileNested(std::string fileName, Datasets& datasets, std::vector<json>& combinations);
|
||||||
struct ConfigGrid config;
|
struct ConfigGrid config;
|
||||||
|
Timer timer; // used to measure the time of the whole process
|
||||||
};
|
};
|
||||||
} /* namespace platform */
|
} /* namespace platform */
|
||||||
#endif /* GRIDSEARCH_H */
|
#endif /* GRIDSEARCH_H */
|
@ -20,9 +20,14 @@ namespace platform {
|
|||||||
std::chrono::duration<double> time_span = std::chrono::duration_cast<std::chrono::duration<double >> (end - begin);
|
std::chrono::duration<double> time_span = std::chrono::duration_cast<std::chrono::duration<double >> (end - begin);
|
||||||
return time_span.count();
|
return time_span.count();
|
||||||
}
|
}
|
||||||
std::string getDurationString()
|
double getLapse()
|
||||||
{
|
{
|
||||||
double duration = getDuration();
|
std::chrono::duration<double> time_span = std::chrono::duration_cast<std::chrono::duration<double >> (std::chrono::high_resolution_clock::now() - begin);
|
||||||
|
return time_span.count();
|
||||||
|
}
|
||||||
|
std::string getDurationString(bool lapse = false)
|
||||||
|
{
|
||||||
|
double duration = lapse ? getLapse() : getDuration();
|
||||||
double durationShow = duration > 3600 ? duration / 3600 : duration > 60 ? duration / 60 : duration;
|
double durationShow = duration > 3600 ? duration / 3600 : duration > 60 ? duration / 60 : duration;
|
||||||
std::string durationUnit = duration > 3600 ? "h" : duration > 60 ? "m" : "s";
|
std::string durationUnit = duration > 3600 ? "h" : duration > 60 ? "m" : "s";
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
|
@ -33,7 +33,7 @@ void manageArguments(argparse::ArgumentParser& program)
|
|||||||
program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
|
program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
|
||||||
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
|
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
|
||||||
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
|
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
|
||||||
program.add_argument("--continue").help("Continue computing from that dataset").default_value("No");
|
program.add_argument("--continue").help("Continue computing from that dataset").default_value(platform::GridSearch::NO_CONTINUE());
|
||||||
program.add_argument("--only").help("Used with continue to compute that dataset only").default_value(false).implicit_value(true);
|
program.add_argument("--only").help("Used with continue to compute that dataset only").default_value(false).implicit_value(true);
|
||||||
program.add_argument("--nested").help("Do a double/nested cross validation with n folds").default_value(0).scan<'i', int>();
|
program.add_argument("--nested").help("Do a double/nested cross validation with n folds").default_value(0).scan<'i', int>();
|
||||||
program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy");
|
program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy");
|
||||||
@ -95,7 +95,7 @@ void list_results(json& results, std::string& model)
|
|||||||
{
|
{
|
||||||
std::cout << Colors::MAGENTA() << std::string(MAXL, '*') << std::endl;
|
std::cout << Colors::MAGENTA() << std::string(MAXL, '*') << std::endl;
|
||||||
std::cout << headerLine("Listing computed hyperparameters for model " + model);
|
std::cout << headerLine("Listing computed hyperparameters for model " + model);
|
||||||
std::cout << headerLine("Date & time: " + results["date"].get<std::string>());
|
std::cout << headerLine("Date & time: " + results["date"].get<std::string>() + " Duration: " + results["duration"].get<std::string>());
|
||||||
std::cout << headerLine("Score: " + results["score"].get<std::string>());
|
std::cout << headerLine("Score: " + results["score"].get<std::string>());
|
||||||
std::cout << headerLine(
|
std::cout << headerLine(
|
||||||
"Random seeds: " + results["seeds"].dump()
|
"Random seeds: " + results["seeds"].dump()
|
||||||
@ -118,9 +118,9 @@ void list_results(json& results, std::string& model)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::cout << Colors::GREEN() << " # " << left << setw(spaces) << "Dataset" << " " << setw(19) << "Date" << " "
|
std::cout << Colors::GREEN() << " # " << left << setw(spaces) << "Dataset" << " " << setw(19) << "Date" << " "
|
||||||
<< setw(8) << "Score" << " " << "Hyperparameters" << std::endl;
|
<< "Duration " << setw(8) << "Score" << " " << "Hyperparameters" << std::endl;
|
||||||
std::cout << "=== " << string(spaces, '=') << " " << string(19, '=') << " " << string(8, '=') << " "
|
std::cout << "=== " << string(spaces, '=') << " " << string(19, '=') << " " << string(8, '=') << " "
|
||||||
<< string(hyperparameters_spaces, '=') << std::endl;
|
<< string(8, '=') << " " << string(hyperparameters_spaces, '=') << std::endl;
|
||||||
bool odd = true;
|
bool odd = true;
|
||||||
int index = 0;
|
int index = 0;
|
||||||
for (const auto& item : results["results"].items()) {
|
for (const auto& item : results["results"].items()) {
|
||||||
@ -130,8 +130,8 @@ void list_results(json& results, std::string& model)
|
|||||||
std::cout << color;
|
std::cout << color;
|
||||||
std::cout << std::setw(3) << std::right << index++ << " ";
|
std::cout << std::setw(3) << std::right << index++ << " ";
|
||||||
std::cout << left << setw(spaces) << key << " " << value["date"].get<string>()
|
std::cout << left << setw(spaces) << key << " " << value["date"].get<string>()
|
||||||
<< " " << setw(8) << setprecision(6) << fixed << right
|
<< " " << setw(8) << value["duration"] << " " << setw(8) << setprecision(6)
|
||||||
<< value["score"].get<double>() << " " << value["hyperparameters"].dump() << std::endl;
|
<< fixed << right << value["score"].get<double>() << " " << value["hyperparameters"].dump() << std::endl;
|
||||||
odd = !odd;
|
odd = !odd;
|
||||||
}
|
}
|
||||||
std::cout << Colors::RESET() << std::endl;
|
std::cout << Colors::RESET() << std::endl;
|
||||||
@ -159,12 +159,12 @@ int main(int argc, char** argv)
|
|||||||
config.seeds = program.get<std::vector<int>>("seeds");
|
config.seeds = program.get<std::vector<int>>("seeds");
|
||||||
config.nested = program.get<int>("nested");
|
config.nested = program.get<int>("nested");
|
||||||
config.continue_from = program.get<std::string>("continue");
|
config.continue_from = program.get<std::string>("continue");
|
||||||
if (config.continue_from == "No" && config.only) {
|
if (config.continue_from == platform::GridSearch::NO_CONTINUE() && config.only) {
|
||||||
throw std::runtime_error("Cannot use --only without --continue");
|
throw std::runtime_error("Cannot use --only without --continue");
|
||||||
}
|
}
|
||||||
dump = program.get<bool>("dump");
|
dump = program.get<bool>("dump");
|
||||||
compute = program.get<bool>("compute");
|
compute = program.get<bool>("compute");
|
||||||
if (dump && (config.continue_from != "No" || config.only)) {
|
if (dump && (config.continue_from != platform::GridSearch::NO_CONTINUE() || config.only)) {
|
||||||
throw std::runtime_error("Cannot use --dump with --continue or --only");
|
throw std::runtime_error("Cannot use --dump with --continue or --only");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -177,6 +177,7 @@ int main(int argc, char** argv)
|
|||||||
* Begin Processing
|
* Begin Processing
|
||||||
*/
|
*/
|
||||||
auto env = platform::DotEnv();
|
auto env = platform::DotEnv();
|
||||||
|
config.platform = env.get("platform");
|
||||||
platform::Paths::createPath(platform::Paths::grid());
|
platform::Paths::createPath(platform::Paths::grid());
|
||||||
auto grid_search = platform::GridSearch(config);
|
auto grid_search = platform::GridSearch(config);
|
||||||
platform::Timer timer;
|
platform::Timer timer;
|
||||||
@ -185,10 +186,7 @@ int main(int argc, char** argv)
|
|||||||
list_dump(config.model);
|
list_dump(config.model);
|
||||||
} else {
|
} else {
|
||||||
if (compute) {
|
if (compute) {
|
||||||
if (config.nested == 0)
|
grid_search.go();
|
||||||
grid_search.goSingle();
|
|
||||||
else
|
|
||||||
grid_search.goNested();
|
|
||||||
std::cout << "Process took " << timer.getDurationString() << std::endl;
|
std::cout << "Process took " << timer.getDurationString() << std::endl;
|
||||||
} else {
|
} else {
|
||||||
// List results
|
// List results
|
||||||
|
Loading…
Reference in New Issue
Block a user