Add reports to gridsearch
This commit is contained in:
parent
8dbbb65a2f
commit
460d20a402
10
.gitmodules
vendored
10
.gitmodules
vendored
@ -1,15 +1,25 @@
|
|||||||
[submodule "lib/mdlp"]
|
[submodule "lib/mdlp"]
|
||||||
path = lib/mdlp
|
path = lib/mdlp
|
||||||
url = https://github.com/rmontanana/mdlp
|
url = https://github.com/rmontanana/mdlp
|
||||||
|
main = main
|
||||||
|
update = merge
|
||||||
[submodule "lib/catch2"]
|
[submodule "lib/catch2"]
|
||||||
path = lib/catch2
|
path = lib/catch2
|
||||||
|
main = v2.x
|
||||||
|
update = merge
|
||||||
url = https://github.com/catchorg/Catch2.git
|
url = https://github.com/catchorg/Catch2.git
|
||||||
[submodule "lib/argparse"]
|
[submodule "lib/argparse"]
|
||||||
path = lib/argparse
|
path = lib/argparse
|
||||||
url = https://github.com/p-ranav/argparse
|
url = https://github.com/p-ranav/argparse
|
||||||
|
master = master
|
||||||
|
update = merge
|
||||||
[submodule "lib/json"]
|
[submodule "lib/json"]
|
||||||
path = lib/json
|
path = lib/json
|
||||||
url = https://github.com/nlohmann/json.git
|
url = https://github.com/nlohmann/json.git
|
||||||
|
master = master
|
||||||
|
update = merge
|
||||||
[submodule "lib/libxlsxwriter"]
|
[submodule "lib/libxlsxwriter"]
|
||||||
path = lib/libxlsxwriter
|
path = lib/libxlsxwriter
|
||||||
url = https://github.com/jmcnamara/libxlsxwriter.git
|
url = https://github.com/jmcnamara/libxlsxwriter.git
|
||||||
|
main = main
|
||||||
|
update = merge
|
||||||
|
@ -1 +1 @@
|
|||||||
Subproject commit b0930ab0288185815d6dc67af59de7014a6272f7
|
Subproject commit 69dabd88a8e6680b1a1a18397eb3e165e4019ce6
|
@ -32,6 +32,18 @@ namespace platform {
|
|||||||
this->config.output_file = config.path + "grid_" + config.model + "_output.json";
|
this->config.output_file = config.path + "grid_" + config.model + "_output.json";
|
||||||
this->config.input_file = config.path + "grid_" + config.model + "_input.json";
|
this->config.input_file = config.path + "grid_" + config.model + "_input.json";
|
||||||
}
|
}
|
||||||
|
std::vector<json> GridSearch::dump()
|
||||||
|
{
|
||||||
|
return GridData(config.input_file).getGrid();
|
||||||
|
}
|
||||||
|
json GridSearch::getResults()
|
||||||
|
{
|
||||||
|
std::ifstream file(config.output_file);
|
||||||
|
if (file.is_open()) {
|
||||||
|
return json::parse(file);
|
||||||
|
}
|
||||||
|
return json();
|
||||||
|
}
|
||||||
void showProgressComb(const int num, const int total, const std::string& color)
|
void showProgressComb(const int num, const int total, const std::string& color)
|
||||||
{
|
{
|
||||||
int spaces = int(log(total) / log(10)) + 1;
|
int spaces = int(log(total) / log(10)) + 1;
|
||||||
|
@ -28,6 +28,8 @@ namespace platform {
|
|||||||
explicit GridSearch(struct ConfigGrid& config);
|
explicit GridSearch(struct ConfigGrid& config);
|
||||||
void go();
|
void go();
|
||||||
~GridSearch() = default;
|
~GridSearch() = default;
|
||||||
|
std::vector<json> dump();
|
||||||
|
json getResults();
|
||||||
private:
|
private:
|
||||||
void save(json& results) const;
|
void save(json& results) const;
|
||||||
double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters);
|
double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters);
|
||||||
|
@ -5,9 +5,8 @@
|
|||||||
#include "Colors.h"
|
#include "Colors.h"
|
||||||
|
|
||||||
|
|
||||||
argparse::ArgumentParser manageArguments(int argc, char** argv)
|
void manageArguments(argparse::ArgumentParser& program, int argc, char** argv)
|
||||||
{
|
{
|
||||||
argparse::ArgumentParser program("b_sbest");
|
|
||||||
program.add_argument("-m", "--model").default_value("").help("Filter results of the selected model) (any for all models)");
|
program.add_argument("-m", "--model").default_value("").help("Filter results of the selected model) (any for all models)");
|
||||||
program.add_argument("-s", "--score").default_value("").help("Filter results of the score name supplied");
|
program.add_argument("-s", "--score").default_value("").help("Filter results of the score name supplied");
|
||||||
program.add_argument("--build").help("build best score results file").default_value(false).implicit_value(true);
|
program.add_argument("--build").help("build best score results file").default_value(false).implicit_value(true);
|
||||||
@ -28,12 +27,12 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
|
|||||||
catch (...) {
|
catch (...) {
|
||||||
throw std::runtime_error("Number of folds must be an decimal number");
|
throw std::runtime_error("Number of folds must be an decimal number");
|
||||||
}});
|
}});
|
||||||
return program;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int main(int argc, char** argv)
|
int main(int argc, char** argv)
|
||||||
{
|
{
|
||||||
auto program = manageArguments(argc, argv);
|
argparse::ArgumentParser program("b_sbest");
|
||||||
|
manageArguments(program, argc, argv);
|
||||||
std::string model, score;
|
std::string model, score;
|
||||||
bool build, report, friedman, excel;
|
bool build, report, friedman, excel;
|
||||||
double level;
|
double level;
|
||||||
|
@ -6,12 +6,13 @@
|
|||||||
#include "GridSearch.h"
|
#include "GridSearch.h"
|
||||||
#include "Paths.h"
|
#include "Paths.h"
|
||||||
#include "Timer.h"
|
#include "Timer.h"
|
||||||
|
#include "Colors.h"
|
||||||
|
|
||||||
|
|
||||||
argparse::ArgumentParser manageArguments(std::string program_name)
|
void manageArguments(argparse::ArgumentParser& program)
|
||||||
{
|
{
|
||||||
auto env = platform::DotEnv();
|
auto env = platform::DotEnv();
|
||||||
argparse::ArgumentParser program(program_name);
|
auto& group = program.add_mutually_exclusive_group(true);
|
||||||
program.add_argument("-m", "--model")
|
program.add_argument("-m", "--model")
|
||||||
.help("Model to use " + platform::Models::instance()->tostring())
|
.help("Model to use " + platform::Models::instance()->tostring())
|
||||||
.action([](const std::string& value) {
|
.action([](const std::string& value) {
|
||||||
@ -22,6 +23,9 @@ argparse::ArgumentParser manageArguments(std::string program_name)
|
|||||||
throw std::runtime_error("Model must be one of " + platform::Models::instance()->tostring());
|
throw std::runtime_error("Model must be one of " + platform::Models::instance()->tostring());
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
group.add_argument("--dump").help("Show the grid combinations").default_value(false).implicit_value(true);
|
||||||
|
group.add_argument("--list").help("List the computed hyperparameters").default_value(false).implicit_value(true);
|
||||||
|
group.add_argument("--compute").help("Perform computation of the grid output hyperparameters").default_value(false).implicit_value(true);
|
||||||
program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
|
program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
|
||||||
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
|
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
|
||||||
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
|
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
|
||||||
@ -44,13 +48,14 @@ argparse::ArgumentParser manageArguments(std::string program_name)
|
|||||||
}});
|
}});
|
||||||
auto seed_values = env.getSeeds();
|
auto seed_values = env.getSeeds();
|
||||||
program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values);
|
program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values);
|
||||||
return program;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int main(int argc, char** argv)
|
int main(int argc, char** argv)
|
||||||
{
|
{
|
||||||
auto program = manageArguments("b_grid");
|
argparse::ArgumentParser program("b_grid");
|
||||||
|
manageArguments(program);
|
||||||
struct platform::ConfigGrid config;
|
struct platform::ConfigGrid config;
|
||||||
|
bool dump, compute, list;
|
||||||
try {
|
try {
|
||||||
program.parse_args(argc, argv);
|
program.parse_args(argc, argv);
|
||||||
config.model = program.get<std::string>("model");
|
config.model = program.get<std::string>("model");
|
||||||
@ -65,6 +70,12 @@ int main(int argc, char** argv)
|
|||||||
if (config.continue_from == "No" && config.only) {
|
if (config.continue_from == "No" && config.only) {
|
||||||
throw std::runtime_error("Cannot use --only without --continue");
|
throw std::runtime_error("Cannot use --only without --continue");
|
||||||
}
|
}
|
||||||
|
dump = program.get<bool>("dump");
|
||||||
|
compute = program.get<bool>("compute");
|
||||||
|
list = program.get<bool>("list");
|
||||||
|
if (dump && (config.continue_from != "No" || config.only)) {
|
||||||
|
throw std::runtime_error("Cannot use --dump with --continue or --only");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
catch (const exception& err) {
|
catch (const exception& err) {
|
||||||
cerr << err.what() << std::endl;
|
cerr << err.what() << std::endl;
|
||||||
@ -80,8 +91,73 @@ int main(int argc, char** argv)
|
|||||||
auto grid_search = platform::GridSearch(config);
|
auto grid_search = platform::GridSearch(config);
|
||||||
platform::Timer timer;
|
platform::Timer timer;
|
||||||
timer.start();
|
timer.start();
|
||||||
grid_search.go();
|
if (dump) {
|
||||||
std::cout << "Process took " << timer.getDurationString() << std::endl;
|
auto combinations = grid_search.dump();
|
||||||
|
auto total = combinations.size();
|
||||||
|
int spaces = int(log(total) / log(10)) + 1;
|
||||||
|
std::cout << Colors::MAGENTA() << "There are " << total << " combinations" << std::endl << std::endl;
|
||||||
|
int index = 0;
|
||||||
|
int max = 0;
|
||||||
|
for (auto const& item : combinations) {
|
||||||
|
if (item.dump().size() > spaces) {
|
||||||
|
max = item.dump().size();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::cout << Colors::GREEN() << left << setw(spaces) << "#" << left << " " << setw(spaces)
|
||||||
|
<< "Hyperparameters" << std::endl;
|
||||||
|
std::cout << string(spaces, '=') << " " << string(max, '=') << std::endl;
|
||||||
|
bool odd = true;
|
||||||
|
for (auto const& item : combinations) {
|
||||||
|
auto color = odd ? Colors::CYAN() : Colors::BLUE();
|
||||||
|
std::cout << color;
|
||||||
|
std::cout << setw(spaces) << fixed << right << ++index << left << " " << item.dump() << std::endl;
|
||||||
|
odd = !odd;
|
||||||
|
}
|
||||||
|
std::cout << Colors::RESET() << std::endl;
|
||||||
|
} else {
|
||||||
|
if (compute) {
|
||||||
|
grid_search.go();
|
||||||
|
std::cout << "Process took " << timer.getDurationString() << std::endl;
|
||||||
|
} else {
|
||||||
|
std::cout << Colors::MAGENTA() << "Listing computed hyperparameters for model "
|
||||||
|
<< config.model << std::endl << std::endl;
|
||||||
|
auto results = grid_search.getResults();
|
||||||
|
if (results.empty()) {
|
||||||
|
std::cout << "No results found" << std::endl;
|
||||||
|
} else {
|
||||||
|
int spaces = 0;
|
||||||
|
int hyperparameters_spaces = 0;
|
||||||
|
for (const auto& item : results.items()) {
|
||||||
|
auto key = item.key();
|
||||||
|
auto value = item.value();
|
||||||
|
if (key.size() > spaces) {
|
||||||
|
spaces = key.size();
|
||||||
|
}
|
||||||
|
if (value["hyperparameters"].dump().size() > hyperparameters_spaces) {
|
||||||
|
hyperparameters_spaces = value["hyperparameters"].dump().size();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::cout << Colors::GREEN() << " # " << left << setw(spaces) << "Dataset" << " " << setw(19) << "Date" << " "
|
||||||
|
<< setw(8) << "Score" << " " << "Hyperparameters" << std::endl;
|
||||||
|
std::cout << "=== " << string(spaces, '=') << " " << string(19, '=') << " " << string(8, '=') << " "
|
||||||
|
<< string(hyperparameters_spaces, '=') << std::endl;
|
||||||
|
bool odd = true;
|
||||||
|
int index = 0;
|
||||||
|
for (const auto& item : results.items()) {
|
||||||
|
auto color = odd ? Colors::CYAN() : Colors::BLUE();
|
||||||
|
auto key = item.key();
|
||||||
|
auto value = item.value();
|
||||||
|
std::cout << color;
|
||||||
|
std::cout << std::setw(3) << std::right << index++ << " ";
|
||||||
|
std::cout << left << setw(spaces) << key << " " << value["date"].get<string>()
|
||||||
|
<< " " << setw(8) << setprecision(6) << fixed << right
|
||||||
|
<< value["score"].get<double>() << " " << value["hyperparameters"].dump() << std::endl;
|
||||||
|
odd = !odd;
|
||||||
|
}
|
||||||
|
std::cout << Colors::RESET() << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
std::cout << "Done!" << std::endl;
|
std::cout << "Done!" << std::endl;
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@ -11,10 +11,9 @@
|
|||||||
|
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
|
|
||||||
argparse::ArgumentParser manageArguments(std::string program_name)
|
void manageArguments(argparse::ArgumentParser& program)
|
||||||
{
|
{
|
||||||
auto env = platform::DotEnv();
|
auto env = platform::DotEnv();
|
||||||
argparse::ArgumentParser program(program_name);
|
|
||||||
program.add_argument("-d", "--dataset").default_value("").help("Dataset file name");
|
program.add_argument("-d", "--dataset").default_value("").help("Dataset file name");
|
||||||
program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment");
|
program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment");
|
||||||
program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \
|
program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \
|
||||||
@ -50,18 +49,18 @@ argparse::ArgumentParser manageArguments(std::string program_name)
|
|||||||
}});
|
}});
|
||||||
auto seed_values = env.getSeeds();
|
auto seed_values = env.getSeeds();
|
||||||
program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values);
|
program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values);
|
||||||
return program;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int main(int argc, char** argv)
|
int main(int argc, char** argv)
|
||||||
{
|
{
|
||||||
|
argparse::ArgumentParser program("b_main");
|
||||||
|
manageArguments(program);
|
||||||
std::string file_name, model_name, title, hyperparameters_file;
|
std::string file_name, model_name, title, hyperparameters_file;
|
||||||
json hyperparameters_json;
|
json hyperparameters_json;
|
||||||
bool discretize_dataset, stratified, saveResults, quiet;
|
bool discretize_dataset, stratified, saveResults, quiet;
|
||||||
std::vector<int> seeds;
|
std::vector<int> seeds;
|
||||||
std::vector<std::string> filesToTest;
|
std::vector<std::string> filesToTest;
|
||||||
int n_folds;
|
int n_folds;
|
||||||
auto program = manageArguments("b_main");
|
|
||||||
try {
|
try {
|
||||||
program.parse_args(argc, argv);
|
program.parse_args(argc, argv);
|
||||||
file_name = program.get<std::string>("dataset");
|
file_name = program.get<std::string>("dataset");
|
||||||
|
@ -3,9 +3,8 @@
|
|||||||
#include "ManageResults.h"
|
#include "ManageResults.h"
|
||||||
|
|
||||||
|
|
||||||
argparse::ArgumentParser manageArguments(int argc, char** argv)
|
void manageArguments(argparse::ArgumentParser& program, int argc, char** argv)
|
||||||
{
|
{
|
||||||
argparse::ArgumentParser program("b_manage");
|
|
||||||
program.add_argument("-n", "--number").default_value(0).help("Number of results to show (0 = all)").scan<'i', int>();
|
program.add_argument("-n", "--number").default_value(0).help("Number of results to show (0 = all)").scan<'i', int>();
|
||||||
program.add_argument("-m", "--model").default_value("any").help("Filter results of the selected model)");
|
program.add_argument("-m", "--model").default_value("any").help("Filter results of the selected model)");
|
||||||
program.add_argument("-s", "--score").default_value("any").help("Filter results of the score name supplied");
|
program.add_argument("-s", "--score").default_value("any").help("Filter results of the score name supplied");
|
||||||
@ -29,12 +28,12 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
|
|||||||
std::cerr << program;
|
std::cerr << program;
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
return program;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int main(int argc, char** argv)
|
int main(int argc, char** argv)
|
||||||
{
|
{
|
||||||
auto program = manageArguments(argc, argv);
|
auto program = argparse::ArgumentParser("b_manage");
|
||||||
|
manageArguments(program, argc, argv);
|
||||||
int number = program.get<int>("number");
|
int number = program.get<int>("number");
|
||||||
std::string model = program.get<std::string>("model");
|
std::string model = program.get<std::string>("model");
|
||||||
std::string score = program.get<std::string>("score");
|
std::string score = program.get<std::string>("score");
|
||||||
|
Loading…
Reference in New Issue
Block a user