diff --git a/.gitignore b/.gitignore index 424b902..268bb77 100644 --- a/.gitignore +++ b/.gitignore @@ -32,8 +32,7 @@ *.out *.app build/** -build_debug/** -build_release/** +build_*/** *.dSYM/** cmake-build*/** .idea diff --git a/Makefile b/Makefile index f6650a2..cb82162 100644 --- a/Makefile +++ b/Makefile @@ -4,7 +4,7 @@ SHELL := /bin/bash f_release = build_release f_debug = build_debug -app_targets = b_best b_list b_main b_manage +app_targets = b_best b_list b_main b_manage b_grid test_targets = unit_tests_bayesnet unit_tests_platform n_procs = -j 16 @@ -36,10 +36,10 @@ install: ## Copy binary files to bin folder @echo "Destination folder: $(dest)" make buildr @echo ">>> Copying files to $(dest)" - @cp $(f_release)/src/Platform/b_main $(dest) - @cp $(f_release)/src/Platform/b_list $(dest) - @cp $(f_release)/src/Platform/b_manage $(dest) - @cp $(f_release)/src/Platform/b_best $(dest) + for item in $(app_targets); do \ + echo ">>> Copying $$item" ; \ + cp $(f_release)/src/Platform/$$item $(dest) ; \ + done dependency: ## Create a dependency graph diagram of the project (build/dependency.png) @echo ">>> Creating dependency graph diagram of the project..."; diff --git a/sample/sample.cc b/sample/sample.cc index d5f84e9..8024707 100644 --- a/sample/sample.cc +++ b/sample/sample.cc @@ -1,6 +1,6 @@ #include #include -#include +#include #include #include #include diff --git a/src/Platform/CMakeLists.txt b/src/Platform/CMakeLists.txt index fba1656..8fc33a4 100644 --- a/src/Platform/CMakeLists.txt +++ b/src/Platform/CMakeLists.txt @@ -8,12 +8,14 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/json/include) include_directories(${BayesNet_SOURCE_DIR}/lib/libxlsxwriter/include) include_directories(${Python3_INCLUDE_DIRS}) +add_executable(b_best b_best.cc BestResults.cc Result.cc Statistics.cc BestResultsExcel.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc) +add_executable(b_grid b_grid.cc GridSearch.cc Folding.cc) +add_executable(b_list b_list.cc Datasets.cc Dataset.cc) add_executable(b_main b_main.cc Folding.cc Experiment.cc Datasets.cc Dataset.cc Models.cc HyperParameters.cc ReportConsole.cc ReportBase.cc) add_executable(b_manage b_manage.cc Results.cc ManageResults.cc CommandParser.cc Result.cc ReportConsole.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc) -add_executable(b_list b_list.cc Datasets.cc Dataset.cc) -add_executable(b_best b_best.cc BestResults.cc Result.cc Statistics.cc BestResultsExcel.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc) -target_link_libraries(b_main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}" PyWrap) -target_link_libraries(b_manage "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" ArffFiles mdlp) target_link_libraries(b_best Boost::boost "${XLSXWRITER_LIB}" "${TORCH_LIBRARIES}" ArffFiles mdlp) -target_link_libraries(b_list ArffFiles mdlp "${TORCH_LIBRARIES}") \ No newline at end of file +target_link_libraries(b_grid BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}" PyWrap) +target_link_libraries(b_list ArffFiles mdlp "${TORCH_LIBRARIES}") +target_link_libraries(b_main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}" PyWrap) +target_link_libraries(b_manage "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" ArffFiles mdlp) \ No newline at end of file diff --git a/src/Platform/Experiment.h b/src/Platform/Experiment.h index c00d7ff..b7aeda6 100644 --- a/src/Platform/Experiment.h +++ b/src/Platform/Experiment.h @@ -3,30 +3,16 @@ #include #include #include -#include #include "Folding.h" #include "BaseClassifier.h" #include "HyperParameters.h" #include "TAN.h" #include "KDB.h" #include "AODE.h" +#include "Timer.h" namespace platform { using json = nlohmann::json; - class Timer { - private: - std::chrono::high_resolution_clock::time_point begin; - public: - Timer() = default; - ~Timer() = default; - void start() { begin = std::chrono::high_resolution_clock::now(); } - double getDuration() - { - std::chrono::high_resolution_clock::time_point end = std::chrono::high_resolution_clock::now(); - std::chrono::duration time_span = std::chrono::duration_cast> (end - begin); - return time_span.count(); - } - }; class Result { private: std::string dataset, model_version; diff --git a/src/Platform/GridSearch.cc b/src/Platform/GridSearch.cc new file mode 100644 index 0000000..1a9d1f2 --- /dev/null +++ b/src/Platform/GridSearch.cc @@ -0,0 +1,32 @@ +#include "GridSearch.h" + +namespace platform { + + GridSearch::GridSearch(struct ConfigGrid& config) : config(config) + { + this->config.input_file = config.path + "grid_" + config.model + "_input.json"; + this->config.output_file = config.path + "grid_" + config.model + "_output.json"; + } + void GridSearch::go() + { + // // Load datasets + // auto datasets = platform::Datasets(config.input_file); + // // Load hyperparameters + // auto hyperparameters = platform::HyperParameters(datasets.getNames(), config.input_file); + // // Check if hyperparameters are valid + // auto valid_hyperparameters = platform::Models::instance()->getHyperparameters(config.model); + // hyperparameters.check(valid_hyperparameters, config.model); + // // Load model + // auto model = platform::Models::instance()->get(config.model); + // // Run gridsearch + // auto grid = platform::Grid(datasets, hyperparameters, model, config.score, config.discretize, config.stratified, config.n_folds, config.seeds); + // grid.run(); + // // Save results + // grid.save(config.output_file); + } + void GridSearch::save() + { + + } + +} /* namespace platform */ \ No newline at end of file diff --git a/src/Platform/GridSearch.h b/src/Platform/GridSearch.h new file mode 100644 index 0000000..9ded996 --- /dev/null +++ b/src/Platform/GridSearch.h @@ -0,0 +1,30 @@ +#ifndef GRIDSEARCH_H +#define GRIDSEARCH_H +#include +#include + +namespace platform { + struct ConfigGrid { + std::string model; + std::string score; + std::string path; + std::string input_file; + std::string output_file; + bool discretize; + bool stratified; + int n_folds; + std::vector seeds; + }; + class GridSearch { + public: + explicit GridSearch(struct ConfigGrid& config); + void go(); + void save(); + ~GridSearch() = default; + private: + struct ConfigGrid config; + + }; + +} /* namespace platform */ +#endif /* GRIDSEARCH_H */ \ No newline at end of file diff --git a/src/Platform/Paths.h b/src/Platform/Paths.h index c642cae..d3c4422 100644 --- a/src/Platform/Paths.h +++ b/src/Platform/Paths.h @@ -9,6 +9,7 @@ namespace platform { static std::string hiddenResults() { return "hidden_results/"; } static std::string excel() { return "excel/"; } static std::string cfs() { return "cfs/"; } + static std::string grid() { return "grid/"; } static std::string datasets() { auto env = platform::DotEnv(); diff --git a/src/Platform/Timer.h b/src/Platform/Timer.h new file mode 100644 index 0000000..87db481 --- /dev/null +++ b/src/Platform/Timer.h @@ -0,0 +1,34 @@ +#ifndef TIMER_H +#define TIMER_H +#include +#include +#include + +namespace platform { + class Timer { + private: + std::chrono::high_resolution_clock::time_point begin; + std::chrono::high_resolution_clock::time_point end; + public: + Timer() = default; + ~Timer() = default; + void start() { begin = std::chrono::high_resolution_clock::now(); } + void stop() { end = std::chrono::high_resolution_clock::now(); } + double getDuration() + { + stop(); + std::chrono::duration time_span = std::chrono::duration_cast> (end - begin); + return time_span.count(); + } + std::string getDurationString() + { + double duration = getDuration(); + double durationShow = duration > 3600 ? duration / 3600 : duration > 60 ? duration / 60 : duration; + std::string durationUnit = duration > 3600 ? "h" : duration > 60 ? "m" : "s"; + std::stringstream ss; + ss << std::setw(7) << std::setprecision(2) << std::fixed << durationShow << " " << durationUnit << " "; + return ss.str(); + } + }; +} /* namespace platform */ +#endif /* TIMER_H */ \ No newline at end of file diff --git a/src/Platform/b_best.cc b/src/Platform/b_best.cc index b559d03..1ed73c7 100644 --- a/src/Platform/b_best.cc +++ b/src/Platform/b_best.cc @@ -7,7 +7,7 @@ argparse::ArgumentParser manageArguments(int argc, char** argv) { - argparse::ArgumentParser program("best"); + argparse::ArgumentParser program("b_sbest"); program.add_argument("-m", "--model").default_value("").help("Filter results of the selected model) (any for all models)"); program.add_argument("-s", "--score").default_value("").help("Filter results of the score name supplied"); program.add_argument("--build").help("build best score results file").default_value(false).implicit_value(true); diff --git a/src/Platform/b_grid.cc b/src/Platform/b_grid.cc new file mode 100644 index 0000000..0d0851c --- /dev/null +++ b/src/Platform/b_grid.cc @@ -0,0 +1,80 @@ +#include +#include +#include "DotEnv.h" +#include "Models.h" +#include "modelRegister.h" +#include "GridSearch.h" +#include "Paths.h" +#include "Timer.h" + + + +argparse::ArgumentParser manageArguments(std::string program_name) +{ + auto env = platform::DotEnv(); + argparse::ArgumentParser program(program_name); + program.add_argument("-m", "--model") + .help("Model to use " + platform::Models::instance()->tostring()) + .action([](const std::string& value) { + static const std::vector choices = platform::Models::instance()->getNames(); + if (find(choices.begin(), choices.end(), value) != choices.end()) { + return value; + } + throw std::runtime_error("Model must be one of " + platform::Models::instance()->tostring()); + } + ); + program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true); + program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true); + program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy"); + program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) { + try { + auto k = stoi(value); + if (k < 2) { + throw std::runtime_error("Number of folds must be greater than 1"); + } + return k; + } + catch (const runtime_error& err) { + throw std::runtime_error(err.what()); + } + catch (...) { + throw std::runtime_error("Number of folds must be an integer"); + }}); + auto seed_values = env.getSeeds(); + program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values); + return program; +} + +int main(int argc, char** argv) +{ + auto program = manageArguments("b_grid"); + struct platform::ConfigGrid config; + try { + program.parse_args(argc, argv); + config.model = program.get("model"); + config.score = program.get("score"); + config.discretize = program.get("discretize"); + config.stratified = program.get("stratified"); + config.n_folds = program.get("folds"); + config.seeds = program.get>("seeds"); + } + catch (const exception& err) { + cerr << err.what() << std::endl; + cerr << program; + exit(1); + } + + /* + * Begin Processing + */ + auto env = platform::DotEnv(); + config.path = platform::Paths::grid(); + auto grid_search = platform::GridSearch(config); + platform::Timer timer; + timer.start(); + grid_search.go(); + std::cout << "Process took " << timer.getDurationString() << std::endl; + grid_search.save(); + std::cout << "Done!" << std::endl; + return 0; +} diff --git a/src/Platform/b_main.cc b/src/Platform/b_main.cc index bf2c703..c09f071 100644 --- a/src/Platform/b_main.cc +++ b/src/Platform/b_main.cc @@ -11,10 +11,10 @@ using json = nlohmann::json; -argparse::ArgumentParser manageArguments() +argparse::ArgumentParser manageArguments(std::string program_name) { auto env = platform::DotEnv(); - argparse::ArgumentParser program("main"); + argparse::ArgumentParser program(program_name); program.add_argument("-d", "--dataset").default_value("").help("Dataset file name"); program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment"); program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \ @@ -61,7 +61,7 @@ int main(int argc, char** argv) std::vector seeds; std::vector filesToTest; int n_folds; - auto program = manageArguments(); + auto program = manageArguments("b_main"); try { program.parse_args(argc, argv); file_name = program.get("dataset"); diff --git a/src/Platform/b_manage.cc b/src/Platform/b_manage.cc index d4b6fa1..1067902 100644 --- a/src/Platform/b_manage.cc +++ b/src/Platform/b_manage.cc @@ -5,7 +5,7 @@ argparse::ArgumentParser manageArguments(int argc, char** argv) { - argparse::ArgumentParser program("manage"); + argparse::ArgumentParser program("b_manage"); program.add_argument("-n", "--number").default_value(0).help("Number of results to show (0 = all)").scan<'i', int>(); program.add_argument("-m", "--model").default_value("any").help("Filter results of the selected model)"); program.add_argument("-s", "--score").default_value("any").help("Filter results of the score name supplied");