Add reports to gridsearch
This commit is contained in:
parent
8dbbb65a2f
commit
460d20a402
10
.gitmodules
vendored
10
.gitmodules
vendored
@ -1,15 +1,25 @@
|
||||
[submodule "lib/mdlp"]
|
||||
path = lib/mdlp
|
||||
url = https://github.com/rmontanana/mdlp
|
||||
main = main
|
||||
update = merge
|
||||
[submodule "lib/catch2"]
|
||||
path = lib/catch2
|
||||
main = v2.x
|
||||
update = merge
|
||||
url = https://github.com/catchorg/Catch2.git
|
||||
[submodule "lib/argparse"]
|
||||
path = lib/argparse
|
||||
url = https://github.com/p-ranav/argparse
|
||||
master = master
|
||||
update = merge
|
||||
[submodule "lib/json"]
|
||||
path = lib/json
|
||||
url = https://github.com/nlohmann/json.git
|
||||
master = master
|
||||
update = merge
|
||||
[submodule "lib/libxlsxwriter"]
|
||||
path = lib/libxlsxwriter
|
||||
url = https://github.com/jmcnamara/libxlsxwriter.git
|
||||
main = main
|
||||
update = merge
|
||||
|
@ -1 +1 @@
|
||||
Subproject commit b0930ab0288185815d6dc67af59de7014a6272f7
|
||||
Subproject commit 69dabd88a8e6680b1a1a18397eb3e165e4019ce6
|
@ -32,6 +32,18 @@ namespace platform {
|
||||
this->config.output_file = config.path + "grid_" + config.model + "_output.json";
|
||||
this->config.input_file = config.path + "grid_" + config.model + "_input.json";
|
||||
}
|
||||
std::vector<json> GridSearch::dump()
|
||||
{
|
||||
return GridData(config.input_file).getGrid();
|
||||
}
|
||||
json GridSearch::getResults()
|
||||
{
|
||||
std::ifstream file(config.output_file);
|
||||
if (file.is_open()) {
|
||||
return json::parse(file);
|
||||
}
|
||||
return json();
|
||||
}
|
||||
void showProgressComb(const int num, const int total, const std::string& color)
|
||||
{
|
||||
int spaces = int(log(total) / log(10)) + 1;
|
||||
|
@ -28,6 +28,8 @@ namespace platform {
|
||||
explicit GridSearch(struct ConfigGrid& config);
|
||||
void go();
|
||||
~GridSearch() = default;
|
||||
std::vector<json> dump();
|
||||
json getResults();
|
||||
private:
|
||||
void save(json& results) const;
|
||||
double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters);
|
||||
|
@ -5,9 +5,8 @@
|
||||
#include "Colors.h"
|
||||
|
||||
|
||||
argparse::ArgumentParser manageArguments(int argc, char** argv)
|
||||
void manageArguments(argparse::ArgumentParser& program, int argc, char** argv)
|
||||
{
|
||||
argparse::ArgumentParser program("b_sbest");
|
||||
program.add_argument("-m", "--model").default_value("").help("Filter results of the selected model) (any for all models)");
|
||||
program.add_argument("-s", "--score").default_value("").help("Filter results of the score name supplied");
|
||||
program.add_argument("--build").help("build best score results file").default_value(false).implicit_value(true);
|
||||
@ -28,12 +27,12 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
|
||||
catch (...) {
|
||||
throw std::runtime_error("Number of folds must be an decimal number");
|
||||
}});
|
||||
return program;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
auto program = manageArguments(argc, argv);
|
||||
argparse::ArgumentParser program("b_sbest");
|
||||
manageArguments(program, argc, argv);
|
||||
std::string model, score;
|
||||
bool build, report, friedman, excel;
|
||||
double level;
|
||||
|
@ -6,12 +6,13 @@
|
||||
#include "GridSearch.h"
|
||||
#include "Paths.h"
|
||||
#include "Timer.h"
|
||||
#include "Colors.h"
|
||||
|
||||
|
||||
argparse::ArgumentParser manageArguments(std::string program_name)
|
||||
void manageArguments(argparse::ArgumentParser& program)
|
||||
{
|
||||
auto env = platform::DotEnv();
|
||||
argparse::ArgumentParser program(program_name);
|
||||
auto& group = program.add_mutually_exclusive_group(true);
|
||||
program.add_argument("-m", "--model")
|
||||
.help("Model to use " + platform::Models::instance()->tostring())
|
||||
.action([](const std::string& value) {
|
||||
@ -22,6 +23,9 @@ argparse::ArgumentParser manageArguments(std::string program_name)
|
||||
throw std::runtime_error("Model must be one of " + platform::Models::instance()->tostring());
|
||||
}
|
||||
);
|
||||
group.add_argument("--dump").help("Show the grid combinations").default_value(false).implicit_value(true);
|
||||
group.add_argument("--list").help("List the computed hyperparameters").default_value(false).implicit_value(true);
|
||||
group.add_argument("--compute").help("Perform computation of the grid output hyperparameters").default_value(false).implicit_value(true);
|
||||
program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
|
||||
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
|
||||
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
|
||||
@ -44,13 +48,14 @@ argparse::ArgumentParser manageArguments(std::string program_name)
|
||||
}});
|
||||
auto seed_values = env.getSeeds();
|
||||
program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values);
|
||||
return program;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
auto program = manageArguments("b_grid");
|
||||
argparse::ArgumentParser program("b_grid");
|
||||
manageArguments(program);
|
||||
struct platform::ConfigGrid config;
|
||||
bool dump, compute, list;
|
||||
try {
|
||||
program.parse_args(argc, argv);
|
||||
config.model = program.get<std::string>("model");
|
||||
@ -65,6 +70,12 @@ int main(int argc, char** argv)
|
||||
if (config.continue_from == "No" && config.only) {
|
||||
throw std::runtime_error("Cannot use --only without --continue");
|
||||
}
|
||||
dump = program.get<bool>("dump");
|
||||
compute = program.get<bool>("compute");
|
||||
list = program.get<bool>("list");
|
||||
if (dump && (config.continue_from != "No" || config.only)) {
|
||||
throw std::runtime_error("Cannot use --dump with --continue or --only");
|
||||
}
|
||||
}
|
||||
catch (const exception& err) {
|
||||
cerr << err.what() << std::endl;
|
||||
@ -80,8 +91,73 @@ int main(int argc, char** argv)
|
||||
auto grid_search = platform::GridSearch(config);
|
||||
platform::Timer timer;
|
||||
timer.start();
|
||||
grid_search.go();
|
||||
std::cout << "Process took " << timer.getDurationString() << std::endl;
|
||||
if (dump) {
|
||||
auto combinations = grid_search.dump();
|
||||
auto total = combinations.size();
|
||||
int spaces = int(log(total) / log(10)) + 1;
|
||||
std::cout << Colors::MAGENTA() << "There are " << total << " combinations" << std::endl << std::endl;
|
||||
int index = 0;
|
||||
int max = 0;
|
||||
for (auto const& item : combinations) {
|
||||
if (item.dump().size() > spaces) {
|
||||
max = item.dump().size();
|
||||
}
|
||||
}
|
||||
std::cout << Colors::GREEN() << left << setw(spaces) << "#" << left << " " << setw(spaces)
|
||||
<< "Hyperparameters" << std::endl;
|
||||
std::cout << string(spaces, '=') << " " << string(max, '=') << std::endl;
|
||||
bool odd = true;
|
||||
for (auto const& item : combinations) {
|
||||
auto color = odd ? Colors::CYAN() : Colors::BLUE();
|
||||
std::cout << color;
|
||||
std::cout << setw(spaces) << fixed << right << ++index << left << " " << item.dump() << std::endl;
|
||||
odd = !odd;
|
||||
}
|
||||
std::cout << Colors::RESET() << std::endl;
|
||||
} else {
|
||||
if (compute) {
|
||||
grid_search.go();
|
||||
std::cout << "Process took " << timer.getDurationString() << std::endl;
|
||||
} else {
|
||||
std::cout << Colors::MAGENTA() << "Listing computed hyperparameters for model "
|
||||
<< config.model << std::endl << std::endl;
|
||||
auto results = grid_search.getResults();
|
||||
if (results.empty()) {
|
||||
std::cout << "No results found" << std::endl;
|
||||
} else {
|
||||
int spaces = 0;
|
||||
int hyperparameters_spaces = 0;
|
||||
for (const auto& item : results.items()) {
|
||||
auto key = item.key();
|
||||
auto value = item.value();
|
||||
if (key.size() > spaces) {
|
||||
spaces = key.size();
|
||||
}
|
||||
if (value["hyperparameters"].dump().size() > hyperparameters_spaces) {
|
||||
hyperparameters_spaces = value["hyperparameters"].dump().size();
|
||||
}
|
||||
}
|
||||
std::cout << Colors::GREEN() << " # " << left << setw(spaces) << "Dataset" << " " << setw(19) << "Date" << " "
|
||||
<< setw(8) << "Score" << " " << "Hyperparameters" << std::endl;
|
||||
std::cout << "=== " << string(spaces, '=') << " " << string(19, '=') << " " << string(8, '=') << " "
|
||||
<< string(hyperparameters_spaces, '=') << std::endl;
|
||||
bool odd = true;
|
||||
int index = 0;
|
||||
for (const auto& item : results.items()) {
|
||||
auto color = odd ? Colors::CYAN() : Colors::BLUE();
|
||||
auto key = item.key();
|
||||
auto value = item.value();
|
||||
std::cout << color;
|
||||
std::cout << std::setw(3) << std::right << index++ << " ";
|
||||
std::cout << left << setw(spaces) << key << " " << value["date"].get<string>()
|
||||
<< " " << setw(8) << setprecision(6) << fixed << right
|
||||
<< value["score"].get<double>() << " " << value["hyperparameters"].dump() << std::endl;
|
||||
odd = !odd;
|
||||
}
|
||||
std::cout << Colors::RESET() << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
std::cout << "Done!" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
@ -11,10 +11,9 @@
|
||||
|
||||
using json = nlohmann::json;
|
||||
|
||||
argparse::ArgumentParser manageArguments(std::string program_name)
|
||||
void manageArguments(argparse::ArgumentParser& program)
|
||||
{
|
||||
auto env = platform::DotEnv();
|
||||
argparse::ArgumentParser program(program_name);
|
||||
program.add_argument("-d", "--dataset").default_value("").help("Dataset file name");
|
||||
program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment");
|
||||
program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \
|
||||
@ -50,18 +49,18 @@ argparse::ArgumentParser manageArguments(std::string program_name)
|
||||
}});
|
||||
auto seed_values = env.getSeeds();
|
||||
program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values);
|
||||
return program;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
argparse::ArgumentParser program("b_main");
|
||||
manageArguments(program);
|
||||
std::string file_name, model_name, title, hyperparameters_file;
|
||||
json hyperparameters_json;
|
||||
bool discretize_dataset, stratified, saveResults, quiet;
|
||||
std::vector<int> seeds;
|
||||
std::vector<std::string> filesToTest;
|
||||
int n_folds;
|
||||
auto program = manageArguments("b_main");
|
||||
try {
|
||||
program.parse_args(argc, argv);
|
||||
file_name = program.get<std::string>("dataset");
|
||||
|
@ -3,9 +3,8 @@
|
||||
#include "ManageResults.h"
|
||||
|
||||
|
||||
argparse::ArgumentParser manageArguments(int argc, char** argv)
|
||||
void manageArguments(argparse::ArgumentParser& program, int argc, char** argv)
|
||||
{
|
||||
argparse::ArgumentParser program("b_manage");
|
||||
program.add_argument("-n", "--number").default_value(0).help("Number of results to show (0 = all)").scan<'i', int>();
|
||||
program.add_argument("-m", "--model").default_value("any").help("Filter results of the selected model)");
|
||||
program.add_argument("-s", "--score").default_value("any").help("Filter results of the score name supplied");
|
||||
@ -29,12 +28,12 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
|
||||
std::cerr << program;
|
||||
exit(1);
|
||||
}
|
||||
return program;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
auto program = manageArguments(argc, argv);
|
||||
auto program = argparse::ArgumentParser("b_manage");
|
||||
manageArguments(program, argc, argv);
|
||||
int number = program.get<int>("number");
|
||||
std::string model = program.get<std::string>("model");
|
||||
std::string score = program.get<std::string>("score");
|
||||
|
Loading…
Reference in New Issue
Block a user