Add header to grid output and report
This commit is contained in:
parent
c460ef46ed
commit
33cd32c639
@ -63,7 +63,7 @@ namespace platform {
|
|||||||
return Colors::RESET();
|
return Colors::RESET();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
double GridSearch::processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters)
|
double GridSearch::processFileSingle(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters)
|
||||||
{
|
{
|
||||||
// Get dataset
|
// Get dataset
|
||||||
auto [X, y] = datasets.getTensors(fileName);
|
auto [X, y] = datasets.getTensors(fileName);
|
||||||
@ -135,11 +135,8 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
return datasets_names;
|
return datasets_names;
|
||||||
}
|
}
|
||||||
|
json GridSearch::initializeResults()
|
||||||
void GridSearch::go()
|
|
||||||
{
|
{
|
||||||
auto datasets = Datasets(config.discretize, Paths::datasets());
|
|
||||||
auto datasets_names = processDatasets(datasets);
|
|
||||||
// Load previous results
|
// Load previous results
|
||||||
json results;
|
json results;
|
||||||
if (config.continue_from != "No") {
|
if (config.continue_from != "No") {
|
||||||
@ -149,6 +146,7 @@ namespace platform {
|
|||||||
std::ifstream file(Paths::grid_output(config.model));
|
std::ifstream file(Paths::grid_output(config.model));
|
||||||
if (file.is_open()) {
|
if (file.is_open()) {
|
||||||
results = json::parse(file);
|
results = json::parse(file);
|
||||||
|
results = results["results"];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
catch (const std::exception& e) {
|
catch (const std::exception& e) {
|
||||||
@ -157,7 +155,15 @@ namespace platform {
|
|||||||
results = json();
|
results = json();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::cout << "***************** Starting Gridsearch *****************" << std::endl;
|
return results;
|
||||||
|
}
|
||||||
|
|
||||||
|
void GridSearch::goSingle()
|
||||||
|
{
|
||||||
|
auto datasets = Datasets(config.discretize, Paths::datasets());
|
||||||
|
auto datasets_names = processDatasets(datasets);
|
||||||
|
json results = initializeResults();
|
||||||
|
std::cout << "***************** Starting Single Gridsearch *****************" << std::endl;
|
||||||
std::cout << "input file=" << Paths::grid_input(config.model) << std::endl;
|
std::cout << "input file=" << Paths::grid_input(config.model) << std::endl;
|
||||||
auto grid = GridData(Paths::grid_input(config.model));
|
auto grid = GridData(Paths::grid_input(config.model));
|
||||||
// Generate hyperparameters grid & run gridsearch
|
// Generate hyperparameters grid & run gridsearch
|
||||||
@ -174,7 +180,7 @@ namespace platform {
|
|||||||
if (!config.quiet)
|
if (!config.quiet)
|
||||||
showProgressComb(++num, totalComb, Colors::CYAN());
|
showProgressComb(++num, totalComb, Colors::CYAN());
|
||||||
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
|
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
|
||||||
double score = processFile(dataset, datasets, hyperparameters);
|
double score = processFileSingle(dataset, datasets, hyperparameters);
|
||||||
if (score > bestScore) {
|
if (score > bestScore) {
|
||||||
bestScore = score;
|
bestScore = score;
|
||||||
bestHyperparameters = hyperparam_line;
|
bestHyperparameters = hyperparam_line;
|
||||||
@ -184,20 +190,80 @@ namespace platform {
|
|||||||
std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed
|
std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed
|
||||||
<< bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl;
|
<< bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl;
|
||||||
}
|
}
|
||||||
results[dataset]["score"] = bestScore;
|
json result = {
|
||||||
results[dataset]["hyperparameters"] = bestHyperparameters;
|
{ "score", bestScore },
|
||||||
results[dataset]["date"] = get_date() + " " + get_time();
|
{ "hyperparameters", bestHyperparameters },
|
||||||
results[dataset]["grid"] = grid.getInputGrid(dataset);
|
{ "date", get_date() + " " + get_time() },
|
||||||
|
{ "grid", grid.getInputGrid(dataset) }
|
||||||
|
};
|
||||||
|
results[dataset] = result;
|
||||||
// Save partial results
|
// Save partial results
|
||||||
save(results);
|
save(results);
|
||||||
}
|
}
|
||||||
// Save final results
|
// Save final results
|
||||||
save(results);
|
save(results);
|
||||||
std::cout << "***************** Ending Gridsearch *******************" << std::endl;
|
std::cout << "***************** Ending Single Gridsearch *******************" << std::endl;
|
||||||
|
}
|
||||||
|
void GridSearch::goNested()
|
||||||
|
{
|
||||||
|
auto datasets = Datasets(config.discretize, Paths::datasets());
|
||||||
|
auto datasets_names = processDatasets(datasets);
|
||||||
|
json results = initializeResults();
|
||||||
|
std::cout << "***************** Starting Nested Gridsearch *****************" << std::endl;
|
||||||
|
std::cout << "input file=" << Paths::grid_input(config.model) << std::endl;
|
||||||
|
auto grid = GridData(Paths::grid_input(config.model));
|
||||||
|
// Generate hyperparameters grid & run gridsearch
|
||||||
|
// Check each combination of hyperparameters for each dataset and each seed
|
||||||
|
for (const auto& dataset : datasets_names) {
|
||||||
|
auto totalComb = grid.getNumCombinations(dataset);
|
||||||
|
if (!config.quiet)
|
||||||
|
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
|
||||||
|
int num = 0;
|
||||||
|
double bestScore = 0.0;
|
||||||
|
json bestHyperparameters;
|
||||||
|
auto combinations = grid.getGrid(dataset);
|
||||||
|
for (const auto& hyperparam_line : combinations) {
|
||||||
|
if (!config.quiet)
|
||||||
|
showProgressComb(++num, totalComb, Colors::CYAN());
|
||||||
|
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
|
||||||
|
double score = processFileSingle(dataset, datasets, hyperparameters);
|
||||||
|
if (score > bestScore) {
|
||||||
|
bestScore = score;
|
||||||
|
bestHyperparameters = hyperparam_line;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!config.quiet) {
|
||||||
|
std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed
|
||||||
|
<< bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl;
|
||||||
|
}
|
||||||
|
json result = {
|
||||||
|
{ "score", bestScore },
|
||||||
|
{ "hyperparameters", bestHyperparameters },
|
||||||
|
{ "date", get_date() + " " + get_time() },
|
||||||
|
{ "grid", grid.getInputGrid(dataset) }
|
||||||
|
};
|
||||||
|
results[dataset] = result;
|
||||||
|
// Save partial results
|
||||||
|
save(results);
|
||||||
|
}
|
||||||
|
// Save final results
|
||||||
|
save(results);
|
||||||
|
std::cout << "***************** Ending Nested Gridsearch *******************" << std::endl;
|
||||||
}
|
}
|
||||||
void GridSearch::save(json& results) const
|
void GridSearch::save(json& results) const
|
||||||
{
|
{
|
||||||
std::ofstream file(Paths::grid_output(config.model));
|
std::ofstream file(Paths::grid_output(config.model));
|
||||||
file << results.dump(4);
|
json output = {
|
||||||
|
{ "model", config.model },
|
||||||
|
{ "score", config.score },
|
||||||
|
{ "discretize", config.discretize },
|
||||||
|
{ "stratified", config.stratified },
|
||||||
|
{ "n_folds", config.n_folds },
|
||||||
|
{ "seeds", config.seeds },
|
||||||
|
{ "date", get_date() + " " + get_time()},
|
||||||
|
{ "nested", config.nested},
|
||||||
|
{ "results", results }
|
||||||
|
};
|
||||||
|
file << output.dump(4);
|
||||||
}
|
}
|
||||||
} /* namespace platform */
|
} /* namespace platform */
|
@ -17,19 +17,22 @@ namespace platform {
|
|||||||
bool only; // used with continue_from to only compute that dataset
|
bool only; // used with continue_from to only compute that dataset
|
||||||
bool discretize;
|
bool discretize;
|
||||||
bool stratified;
|
bool stratified;
|
||||||
|
int nested;
|
||||||
int n_folds;
|
int n_folds;
|
||||||
std::vector<int> seeds;
|
std::vector<int> seeds;
|
||||||
};
|
};
|
||||||
class GridSearch {
|
class GridSearch {
|
||||||
public:
|
public:
|
||||||
explicit GridSearch(struct ConfigGrid& config);
|
explicit GridSearch(struct ConfigGrid& config);
|
||||||
void go();
|
void goSingle();
|
||||||
|
void goNested();
|
||||||
~GridSearch() = default;
|
~GridSearch() = default;
|
||||||
json getResults();
|
json getResults();
|
||||||
private:
|
private:
|
||||||
void save(json& results) const;
|
void save(json& results) const;
|
||||||
|
json initializeResults();
|
||||||
vector<std::string> processDatasets(Datasets& datasets);
|
vector<std::string> processDatasets(Datasets& datasets);
|
||||||
double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters);
|
double processFileSingle(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters);
|
||||||
struct ConfigGrid config;
|
struct ConfigGrid config;
|
||||||
};
|
};
|
||||||
} /* namespace platform */
|
} /* namespace platform */
|
||||||
|
@ -11,6 +11,7 @@
|
|||||||
#include "Colors.h"
|
#include "Colors.h"
|
||||||
|
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
|
const int MAXL = 133;
|
||||||
|
|
||||||
void manageArguments(argparse::ArgumentParser& program)
|
void manageArguments(argparse::ArgumentParser& program)
|
||||||
{
|
{
|
||||||
@ -27,13 +28,14 @@ void manageArguments(argparse::ArgumentParser& program)
|
|||||||
}
|
}
|
||||||
);
|
);
|
||||||
group.add_argument("--dump").help("Show the grid combinations").default_value(false).implicit_value(true);
|
group.add_argument("--dump").help("Show the grid combinations").default_value(false).implicit_value(true);
|
||||||
group.add_argument("--list").help("List the computed hyperparameters").default_value(false).implicit_value(true);
|
group.add_argument("--report").help("Report the computed hyperparameters").default_value(false).implicit_value(true);
|
||||||
group.add_argument("--compute").help("Perform computation of the grid output hyperparameters").default_value(false).implicit_value(true);
|
group.add_argument("--compute").help("Perform computation of the grid output hyperparameters").default_value(false).implicit_value(true);
|
||||||
program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
|
program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
|
||||||
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
|
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
|
||||||
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
|
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
|
||||||
program.add_argument("--continue").help("Continue computing from that dataset").default_value("No");
|
program.add_argument("--continue").help("Continue computing from that dataset").default_value("No");
|
||||||
program.add_argument("--only").help("Used with continue to compute that dataset only").default_value(false).implicit_value(true);
|
program.add_argument("--only").help("Used with continue to compute that dataset only").default_value(false).implicit_value(true);
|
||||||
|
program.add_argument("--nested").help("Do a double/nested cross validation with n folds").default_value(0).scan<'i', int>();
|
||||||
program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy");
|
program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy");
|
||||||
program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) {
|
program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) {
|
||||||
try {
|
try {
|
||||||
@ -83,13 +85,29 @@ void list_dump(std::string& model)
|
|||||||
}
|
}
|
||||||
std::cout << Colors::RESET() << std::endl;
|
std::cout << Colors::RESET() << std::endl;
|
||||||
}
|
}
|
||||||
|
std::string headerLine(const std::string& text, int utf = 0)
|
||||||
|
{
|
||||||
|
int n = MAXL - text.length() - 3;
|
||||||
|
n = n < 0 ? 0 : n;
|
||||||
|
return "* " + text + std::string(n + utf, ' ') + "*\n";
|
||||||
|
}
|
||||||
void list_results(json& results, std::string& model)
|
void list_results(json& results, std::string& model)
|
||||||
{
|
{
|
||||||
std::cout << Colors::MAGENTA() << "Listing computed hyperparameters for model "
|
std::cout << Colors::MAGENTA() << std::string(MAXL, '*') << std::endl;
|
||||||
<< model << std::endl << std::endl;
|
std::cout << headerLine("Listing computed hyperparameters for model " + model);
|
||||||
|
std::cout << headerLine("Date & time: " + results["date"].get<std::string>());
|
||||||
|
std::cout << headerLine("Score: " + results["score"].get<std::string>());
|
||||||
|
std::cout << headerLine(
|
||||||
|
"Random seeds: " + results["seeds"].dump()
|
||||||
|
+ " Discretized: " + (results["discretize"].get<bool>() ? "True" : "False")
|
||||||
|
+ " Stratified: " + (results["stratified"].get<bool>() ? "True" : "False")
|
||||||
|
+ " #Folds: " + std::to_string(results["n_folds"].get<int>())
|
||||||
|
+ " Nested: " + (results["nested"].get<int>() == 0 ? "False" : to_string(results["nested"].get<int>()))
|
||||||
|
);
|
||||||
|
std::cout << std::string(MAXL, '*') << std::endl;
|
||||||
int spaces = 0;
|
int spaces = 0;
|
||||||
int hyperparameters_spaces = 0;
|
int hyperparameters_spaces = 0;
|
||||||
for (const auto& item : results.items()) {
|
for (const auto& item : results["results"].items()) {
|
||||||
auto key = item.key();
|
auto key = item.key();
|
||||||
auto value = item.value();
|
auto value = item.value();
|
||||||
if (key.size() > spaces) {
|
if (key.size() > spaces) {
|
||||||
@ -105,7 +123,7 @@ void list_results(json& results, std::string& model)
|
|||||||
<< string(hyperparameters_spaces, '=') << std::endl;
|
<< string(hyperparameters_spaces, '=') << std::endl;
|
||||||
bool odd = true;
|
bool odd = true;
|
||||||
int index = 0;
|
int index = 0;
|
||||||
for (const auto& item : results.items()) {
|
for (const auto& item : results["results"].items()) {
|
||||||
auto color = odd ? Colors::CYAN() : Colors::BLUE();
|
auto color = odd ? Colors::CYAN() : Colors::BLUE();
|
||||||
auto key = item.key();
|
auto key = item.key();
|
||||||
auto value = item.value();
|
auto value = item.value();
|
||||||
@ -119,12 +137,16 @@ void list_results(json& results, std::string& model)
|
|||||||
std::cout << Colors::RESET() << std::endl;
|
std::cout << Colors::RESET() << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Main
|
||||||
|
*/
|
||||||
int main(int argc, char** argv)
|
int main(int argc, char** argv)
|
||||||
{
|
{
|
||||||
argparse::ArgumentParser program("b_grid");
|
argparse::ArgumentParser program("b_grid");
|
||||||
manageArguments(program);
|
manageArguments(program);
|
||||||
struct platform::ConfigGrid config;
|
struct platform::ConfigGrid config;
|
||||||
bool dump, compute, list;
|
bool dump, compute;
|
||||||
try {
|
try {
|
||||||
program.parse_args(argc, argv);
|
program.parse_args(argc, argv);
|
||||||
config.model = program.get<std::string>("model");
|
config.model = program.get<std::string>("model");
|
||||||
@ -135,13 +157,13 @@ int main(int argc, char** argv)
|
|||||||
config.quiet = program.get<bool>("quiet");
|
config.quiet = program.get<bool>("quiet");
|
||||||
config.only = program.get<bool>("only");
|
config.only = program.get<bool>("only");
|
||||||
config.seeds = program.get<std::vector<int>>("seeds");
|
config.seeds = program.get<std::vector<int>>("seeds");
|
||||||
|
config.nested = program.get<int>("nested");
|
||||||
config.continue_from = program.get<std::string>("continue");
|
config.continue_from = program.get<std::string>("continue");
|
||||||
if (config.continue_from == "No" && config.only) {
|
if (config.continue_from == "No" && config.only) {
|
||||||
throw std::runtime_error("Cannot use --only without --continue");
|
throw std::runtime_error("Cannot use --only without --continue");
|
||||||
}
|
}
|
||||||
dump = program.get<bool>("dump");
|
dump = program.get<bool>("dump");
|
||||||
compute = program.get<bool>("compute");
|
compute = program.get<bool>("compute");
|
||||||
list = program.get<bool>("list");
|
|
||||||
if (dump && (config.continue_from != "No" || config.only)) {
|
if (dump && (config.continue_from != "No" || config.only)) {
|
||||||
throw std::runtime_error("Cannot use --dump with --continue or --only");
|
throw std::runtime_error("Cannot use --dump with --continue or --only");
|
||||||
}
|
}
|
||||||
@ -163,7 +185,10 @@ int main(int argc, char** argv)
|
|||||||
list_dump(config.model);
|
list_dump(config.model);
|
||||||
} else {
|
} else {
|
||||||
if (compute) {
|
if (compute) {
|
||||||
grid_search.go();
|
if (config.nested == 0)
|
||||||
|
grid_search.goSingle();
|
||||||
|
else
|
||||||
|
grid_search.goNested();
|
||||||
std::cout << "Process took " << timer.getDurationString() << std::endl;
|
std::cout << "Process took " << timer.getDurationString() << std::endl;
|
||||||
} else {
|
} else {
|
||||||
// List results
|
// List results
|
||||||
|
Loading…
Reference in New Issue
Block a user