Build gridsearch structure

This commit is contained in:
Ricardo Montañana Gómez 2023-11-20 23:32:34 +01:00
parent 5876be4b24
commit 4628e48d3c
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
13 changed files with 197 additions and 33 deletions

3
.gitignore vendored
View File

@ -32,8 +32,7 @@
*.out
*.app
build/**
build_debug/**
build_release/**
build_*/**
*.dSYM/**
cmake-build*/**
.idea

View File

@ -4,7 +4,7 @@ SHELL := /bin/bash
f_release = build_release
f_debug = build_debug
app_targets = b_best b_list b_main b_manage
app_targets = b_best b_list b_main b_manage b_grid
test_targets = unit_tests_bayesnet unit_tests_platform
n_procs = -j 16
@ -36,10 +36,10 @@ install: ## Copy binary files to bin folder
@echo "Destination folder: $(dest)"
make buildr
@echo ">>> Copying files to $(dest)"
@cp $(f_release)/src/Platform/b_main $(dest)
@cp $(f_release)/src/Platform/b_list $(dest)
@cp $(f_release)/src/Platform/b_manage $(dest)
@cp $(f_release)/src/Platform/b_best $(dest)
for item in $(app_targets); do \
echo ">>> Copying $$item" ; \
cp $(f_release)/src/Platform/$$item $(dest) ; \
done
dependency: ## Create a dependency graph diagram of the project (build/dependency.png)
@echo ">>> Creating dependency graph diagram of the project...";

View File

@ -1,6 +1,6 @@
#include <iostream>
#include <torch/torch.h>
#include <std::string>
#include <string>
#include <map>
#include <argparse/argparse.hpp>
#include <nlohmann/json.hpp>

View File

@ -8,12 +8,14 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/json/include)
include_directories(${BayesNet_SOURCE_DIR}/lib/libxlsxwriter/include)
include_directories(${Python3_INCLUDE_DIRS})
add_executable(b_best b_best.cc BestResults.cc Result.cc Statistics.cc BestResultsExcel.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc)
add_executable(b_grid b_grid.cc GridSearch.cc Folding.cc)
add_executable(b_list b_list.cc Datasets.cc Dataset.cc)
add_executable(b_main b_main.cc Folding.cc Experiment.cc Datasets.cc Dataset.cc Models.cc HyperParameters.cc ReportConsole.cc ReportBase.cc)
add_executable(b_manage b_manage.cc Results.cc ManageResults.cc CommandParser.cc Result.cc ReportConsole.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc)
add_executable(b_list b_list.cc Datasets.cc Dataset.cc)
add_executable(b_best b_best.cc BestResults.cc Result.cc Statistics.cc BestResultsExcel.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc)
target_link_libraries(b_main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}" PyWrap)
target_link_libraries(b_manage "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" ArffFiles mdlp)
target_link_libraries(b_best Boost::boost "${XLSXWRITER_LIB}" "${TORCH_LIBRARIES}" ArffFiles mdlp)
target_link_libraries(b_list ArffFiles mdlp "${TORCH_LIBRARIES}")
target_link_libraries(b_grid BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}" PyWrap)
target_link_libraries(b_list ArffFiles mdlp "${TORCH_LIBRARIES}")
target_link_libraries(b_main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}" PyWrap)
target_link_libraries(b_manage "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" ArffFiles mdlp)

View File

@ -3,30 +3,16 @@
#include <torch/torch.h>
#include <nlohmann/json.hpp>
#include <string>
#include <chrono>
#include "Folding.h"
#include "BaseClassifier.h"
#include "HyperParameters.h"
#include "TAN.h"
#include "KDB.h"
#include "AODE.h"
#include "Timer.h"
namespace platform {
using json = nlohmann::json;
class Timer {
private:
std::chrono::high_resolution_clock::time_point begin;
public:
Timer() = default;
~Timer() = default;
void start() { begin = std::chrono::high_resolution_clock::now(); }
double getDuration()
{
std::chrono::high_resolution_clock::time_point end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> time_span = std::chrono::duration_cast<std::chrono::duration<double >> (end - begin);
return time_span.count();
}
};
class Result {
private:
std::string dataset, model_version;

View File

@ -0,0 +1,32 @@
#include "GridSearch.h"
namespace platform {
GridSearch::GridSearch(struct ConfigGrid& config) : config(config)
{
this->config.input_file = config.path + "grid_" + config.model + "_input.json";
this->config.output_file = config.path + "grid_" + config.model + "_output.json";
}
void GridSearch::go()
{
// // Load datasets
// auto datasets = platform::Datasets(config.input_file);
// // Load hyperparameters
// auto hyperparameters = platform::HyperParameters(datasets.getNames(), config.input_file);
// // Check if hyperparameters are valid
// auto valid_hyperparameters = platform::Models::instance()->getHyperparameters(config.model);
// hyperparameters.check(valid_hyperparameters, config.model);
// // Load model
// auto model = platform::Models::instance()->get(config.model);
// // Run gridsearch
// auto grid = platform::Grid(datasets, hyperparameters, model, config.score, config.discretize, config.stratified, config.n_folds, config.seeds);
// grid.run();
// // Save results
// grid.save(config.output_file);
}
void GridSearch::save()
{
}
} /* namespace platform */

30
src/Platform/GridSearch.h Normal file
View File

@ -0,0 +1,30 @@
#ifndef GRIDSEARCH_H
#define GRIDSEARCH_H
#include <string>
#include <vector>
namespace platform {
struct ConfigGrid {
std::string model;
std::string score;
std::string path;
std::string input_file;
std::string output_file;
bool discretize;
bool stratified;
int n_folds;
std::vector<int> seeds;
};
class GridSearch {
public:
explicit GridSearch(struct ConfigGrid& config);
void go();
void save();
~GridSearch() = default;
private:
struct ConfigGrid config;
};
} /* namespace platform */
#endif /* GRIDSEARCH_H */

View File

@ -9,6 +9,7 @@ namespace platform {
static std::string hiddenResults() { return "hidden_results/"; }
static std::string excel() { return "excel/"; }
static std::string cfs() { return "cfs/"; }
static std::string grid() { return "grid/"; }
static std::string datasets()
{
auto env = platform::DotEnv();

34
src/Platform/Timer.h Normal file
View File

@ -0,0 +1,34 @@
#ifndef TIMER_H
#define TIMER_H
#include <chrono>
#include <string>
#include <sstream>
namespace platform {
class Timer {
private:
std::chrono::high_resolution_clock::time_point begin;
std::chrono::high_resolution_clock::time_point end;
public:
Timer() = default;
~Timer() = default;
void start() { begin = std::chrono::high_resolution_clock::now(); }
void stop() { end = std::chrono::high_resolution_clock::now(); }
double getDuration()
{
stop();
std::chrono::duration<double> time_span = std::chrono::duration_cast<std::chrono::duration<double >> (end - begin);
return time_span.count();
}
std::string getDurationString()
{
double duration = getDuration();
double durationShow = duration > 3600 ? duration / 3600 : duration > 60 ? duration / 60 : duration;
std::string durationUnit = duration > 3600 ? "h" : duration > 60 ? "m" : "s";
std::stringstream ss;
ss << std::setw(7) << std::setprecision(2) << std::fixed << durationShow << " " << durationUnit << " ";
return ss.str();
}
};
} /* namespace platform */
#endif /* TIMER_H */

View File

@ -7,7 +7,7 @@
argparse::ArgumentParser manageArguments(int argc, char** argv)
{
argparse::ArgumentParser program("best");
argparse::ArgumentParser program("b_sbest");
program.add_argument("-m", "--model").default_value("").help("Filter results of the selected model) (any for all models)");
program.add_argument("-s", "--score").default_value("").help("Filter results of the score name supplied");
program.add_argument("--build").help("build best score results file").default_value(false).implicit_value(true);

80
src/Platform/b_grid.cc Normal file
View File

@ -0,0 +1,80 @@
#include <iostream>
#include <argparse/argparse.hpp>
#include "DotEnv.h"
#include "Models.h"
#include "modelRegister.h"
#include "GridSearch.h"
#include "Paths.h"
#include "Timer.h"
argparse::ArgumentParser manageArguments(std::string program_name)
{
auto env = platform::DotEnv();
argparse::ArgumentParser program(program_name);
program.add_argument("-m", "--model")
.help("Model to use " + platform::Models::instance()->tostring())
.action([](const std::string& value) {
static const std::vector<std::string> choices = platform::Models::instance()->getNames();
if (find(choices.begin(), choices.end(), value) != choices.end()) {
return value;
}
throw std::runtime_error("Model must be one of " + platform::Models::instance()->tostring());
}
);
program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy");
program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) {
try {
auto k = stoi(value);
if (k < 2) {
throw std::runtime_error("Number of folds must be greater than 1");
}
return k;
}
catch (const runtime_error& err) {
throw std::runtime_error(err.what());
}
catch (...) {
throw std::runtime_error("Number of folds must be an integer");
}});
auto seed_values = env.getSeeds();
program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values);
return program;
}
int main(int argc, char** argv)
{
auto program = manageArguments("b_grid");
struct platform::ConfigGrid config;
try {
program.parse_args(argc, argv);
config.model = program.get<std::string>("model");
config.score = program.get<std::string>("score");
config.discretize = program.get<bool>("discretize");
config.stratified = program.get<bool>("stratified");
config.n_folds = program.get<int>("folds");
config.seeds = program.get<std::vector<int>>("seeds");
}
catch (const exception& err) {
cerr << err.what() << std::endl;
cerr << program;
exit(1);
}
/*
* Begin Processing
*/
auto env = platform::DotEnv();
config.path = platform::Paths::grid();
auto grid_search = platform::GridSearch(config);
platform::Timer timer;
timer.start();
grid_search.go();
std::cout << "Process took " << timer.getDurationString() << std::endl;
grid_search.save();
std::cout << "Done!" << std::endl;
return 0;
}

View File

@ -11,10 +11,10 @@
using json = nlohmann::json;
argparse::ArgumentParser manageArguments()
argparse::ArgumentParser manageArguments(std::string program_name)
{
auto env = platform::DotEnv();
argparse::ArgumentParser program("main");
argparse::ArgumentParser program(program_name);
program.add_argument("-d", "--dataset").default_value("").help("Dataset file name");
program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment");
program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \
@ -61,7 +61,7 @@ int main(int argc, char** argv)
std::vector<int> seeds;
std::vector<std::string> filesToTest;
int n_folds;
auto program = manageArguments();
auto program = manageArguments("b_main");
try {
program.parse_args(argc, argv);
file_name = program.get<std::string>("dataset");

View File

@ -5,7 +5,7 @@
argparse::ArgumentParser manageArguments(int argc, char** argv)
{
argparse::ArgumentParser program("manage");
argparse::ArgumentParser program("b_manage");
program.add_argument("-n", "--number").default_value(0).help("Number of results to show (0 = all)").scan<'i', int>();
program.add_argument("-m", "--model").default_value("any").help("Filter results of the selected model)");
program.add_argument("-s", "--score").default_value("any").help("Filter results of the score name supplied");