From 495d8a8528b857904447db0840c6bec7c341f22c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Tue, 21 Nov 2023 13:11:14 +0100 Subject: [PATCH] Begin implementing grid combinations --- Makefile | 4 ++- src/Platform/CMakeLists.txt | 4 +-- src/Platform/GridData.cc | 67 +++++++++++++++++++++++++++++++++++++ src/Platform/GridData.h | 21 ++++++++++++ src/Platform/GridSearch.cc | 18 ++++++---- src/Platform/GridSearch.h | 4 +-- src/Platform/b_grid.cc | 2 -- 7 files changed, 107 insertions(+), 13 deletions(-) create mode 100644 src/Platform/GridData.cc create mode 100644 src/Platform/GridData.h diff --git a/Makefile b/Makefile index cb82162..2d3c9c1 100644 --- a/Makefile +++ b/Makefile @@ -35,8 +35,10 @@ dest ?= ${HOME}/bin install: ## Copy binary files to bin folder @echo "Destination folder: $(dest)" make buildr + @echo "*******************************************" @echo ">>> Copying files to $(dest)" - for item in $(app_targets); do \ + @echo "*******************************************" + @for item in $(app_targets); do \ echo ">>> Copying $$item" ; \ cp $(f_release)/src/Platform/$$item $(dest) ; \ done diff --git a/src/Platform/CMakeLists.txt b/src/Platform/CMakeLists.txt index 8fc33a4..6063bdc 100644 --- a/src/Platform/CMakeLists.txt +++ b/src/Platform/CMakeLists.txt @@ -9,13 +9,13 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/libxlsxwriter/include) include_directories(${Python3_INCLUDE_DIRS}) add_executable(b_best b_best.cc BestResults.cc Result.cc Statistics.cc BestResultsExcel.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc) -add_executable(b_grid b_grid.cc GridSearch.cc Folding.cc) +add_executable(b_grid b_grid.cc GridSearch.cc GridData.cc Folding.cc Datasets.cc Dataset.cc) add_executable(b_list b_list.cc Datasets.cc Dataset.cc) add_executable(b_main b_main.cc Folding.cc Experiment.cc Datasets.cc Dataset.cc Models.cc HyperParameters.cc ReportConsole.cc ReportBase.cc) add_executable(b_manage b_manage.cc Results.cc ManageResults.cc CommandParser.cc Result.cc ReportConsole.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc) target_link_libraries(b_best Boost::boost "${XLSXWRITER_LIB}" "${TORCH_LIBRARIES}" ArffFiles mdlp) -target_link_libraries(b_grid BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}" PyWrap) +target_link_libraries(b_grid BayesNet PyWrap) target_link_libraries(b_list ArffFiles mdlp "${TORCH_LIBRARIES}") target_link_libraries(b_main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}" PyWrap) target_link_libraries(b_manage "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" ArffFiles mdlp) \ No newline at end of file diff --git a/src/Platform/GridData.cc b/src/Platform/GridData.cc new file mode 100644 index 0000000..6514616 --- /dev/null +++ b/src/Platform/GridData.cc @@ -0,0 +1,67 @@ +#include "GridData.h" +#include + +namespace platform { + GridData::GridData() + { + auto boostaode = R"( + [ + { + "convergence": [true, false], + "ascending": [true, false], + "repeatSparent": [true, false], + "select_features": ["CFS", "FCBF"], + "tolerance": [0, 3, 5], + "threshold": [1e-7] + }, + { + "convergence": [true, false], + "ascending": [true, false], + "repeatSparent": [true, false], + "select_features": ["IWSS"], + "tolerance": [0, 3, 5], + "threshold": [0.5] + + } + ] + )"_json; + grid["BoostAODE"] = boostaode; + } + int GridData::computeNumCombinations(const json& line) + { + int numCombinations = 1; + for (const auto& item : line) { + for (const auto& hyperparam : item.items()) { + numCombinations *= item.size(); + } + } + return numCombinations; + } + std::vector GridData::doCombination(const std::string& model) + { + int numTotal = 0; + for (const auto& item : grid[model]) { + numTotal += computeNumCombinations(item); + } + auto result = std::vector(numTotal); + int base = 0; + for (const auto& item : grid[model]) { + int numCombinations = computeNumCombinations(item); + int line = 0; + for (const auto& hyperparam : item.items()) { + int numValues = hyperparam.value().size(); + for (const auto& value : hyperparam.value()) { + for (int i = 0; i < numCombinations / numValues; i++) { + result[base + line++][hyperparam.key()] = value; + //std::cout << "line=" << base + line << " " << hyperparam.key() << "=" << value << std::endl; + } + } + } + base += numCombinations; + } + for (const auto& item : result) { + std::cout << item.dump() << std::endl; + } + return result; + } +} /* namespace platform */ \ No newline at end of file diff --git a/src/Platform/GridData.h b/src/Platform/GridData.h new file mode 100644 index 0000000..de60986 --- /dev/null +++ b/src/Platform/GridData.h @@ -0,0 +1,21 @@ +#ifndef GRIDDATA_H +#define GRIDDATA_H +#include +#include +#include +#include + +namespace platform { + using json = nlohmann::json; + class GridData { + public: + GridData(); + ~GridData() = default; + std::vector getGrid(const std::string& model) { return doCombination(model); } + private: + int computeNumCombinations(const json& line); + std::vector doCombination(const std::string& model); + std::map grid; + }; +} /* namespace platform */ +#endif /* GRIDDATA_H */ \ No newline at end of file diff --git a/src/Platform/GridSearch.cc b/src/Platform/GridSearch.cc index 1a9d1f2..3bd3f67 100644 --- a/src/Platform/GridSearch.cc +++ b/src/Platform/GridSearch.cc @@ -1,19 +1,25 @@ +#include #include "GridSearch.h" +#include "Paths.h" +#include "Datasets.h" +#include "HyperParameters.h" namespace platform { - GridSearch::GridSearch(struct ConfigGrid& config) : config(config) { - this->config.input_file = config.path + "grid_" + config.model + "_input.json"; this->config.output_file = config.path + "grid_" + config.model + "_output.json"; } void GridSearch::go() { - // // Load datasets - // auto datasets = platform::Datasets(config.input_file); - // // Load hyperparameters + // Load datasets + auto datasets = platform::Datasets(config.discretize, Paths::datasets()); + int i = 0; + for (const auto& item : grid.getGrid("BoostAODE")) { + std::cout << i++ << " hyperparams: " << item.dump() << std::endl; + } + // Load hyperparameters // auto hyperparameters = platform::HyperParameters(datasets.getNames(), config.input_file); - // // Check if hyperparameters are valid + // Check if hyperparameters are valid // auto valid_hyperparameters = platform::Models::instance()->getHyperparameters(config.model); // hyperparameters.check(valid_hyperparameters, config.model); // // Load model diff --git a/src/Platform/GridSearch.h b/src/Platform/GridSearch.h index 9ded996..1db303c 100644 --- a/src/Platform/GridSearch.h +++ b/src/Platform/GridSearch.h @@ -2,6 +2,7 @@ #define GRIDSEARCH_H #include #include +#include "GridData.h" namespace platform { struct ConfigGrid { @@ -23,8 +24,7 @@ namespace platform { ~GridSearch() = default; private: struct ConfigGrid config; - + GridData grid; }; - } /* namespace platform */ #endif /* GRIDSEARCH_H */ \ No newline at end of file diff --git a/src/Platform/b_grid.cc b/src/Platform/b_grid.cc index 0d0851c..de905bf 100644 --- a/src/Platform/b_grid.cc +++ b/src/Platform/b_grid.cc @@ -8,7 +8,6 @@ #include "Timer.h" - argparse::ArgumentParser manageArguments(std::string program_name) { auto env = platform::DotEnv(); @@ -63,7 +62,6 @@ int main(int argc, char** argv) cerr << program; exit(1); } - /* * Begin Processing */