diff --git a/src/Platform/CMakeLists.txt b/src/Platform/CMakeLists.txt index 6063bdc..d35989f 100644 --- a/src/Platform/CMakeLists.txt +++ b/src/Platform/CMakeLists.txt @@ -9,7 +9,7 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/libxlsxwriter/include) include_directories(${Python3_INCLUDE_DIRS}) add_executable(b_best b_best.cc BestResults.cc Result.cc Statistics.cc BestResultsExcel.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc) -add_executable(b_grid b_grid.cc GridSearch.cc GridData.cc Folding.cc Datasets.cc Dataset.cc) +add_executable(b_grid b_grid.cc GridSearch.cc GridData.cc HyperParameters.cc Folding.cc Datasets.cc Dataset.cc) add_executable(b_list b_list.cc Datasets.cc Dataset.cc) add_executable(b_main b_main.cc Folding.cc Experiment.cc Datasets.cc Dataset.cc Models.cc HyperParameters.cc ReportConsole.cc ReportBase.cc) add_executable(b_manage b_manage.cc Results.cc ManageResults.cc CommandParser.cc Result.cc ReportConsole.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc) diff --git a/src/Platform/Experiment.cc b/src/Platform/Experiment.cc index 5e80fc3..1574f73 100644 --- a/src/Platform/Experiment.cc +++ b/src/Platform/Experiment.cc @@ -133,7 +133,7 @@ namespace platform { } void Experiment::cross_validation(const std::string& fileName, bool quiet) { - auto datasets = platform::Datasets(discretized, Paths::datasets()); + auto datasets = Datasets(discretized, Paths::datasets()); // Get dataset auto [X, y] = datasets.getTensors(fileName); auto states = datasets.getStates(fileName); diff --git a/src/Platform/GridData.cc b/src/Platform/GridData.cc index 6514616..5935d73 100644 --- a/src/Platform/GridData.cc +++ b/src/Platform/GridData.cc @@ -30,37 +30,41 @@ namespace platform { int GridData::computeNumCombinations(const json& line) { int numCombinations = 1; - for (const auto& item : line) { - for (const auto& hyperparam : item.items()) { - numCombinations *= item.size(); - } + for (const auto& item : line.items()) { + numCombinations *= item.value().size(); } return numCombinations; } - std::vector GridData::doCombination(const std::string& model) + int GridData::getNumCombinations(const std::string& model) { - int numTotal = 0; - for (const auto& item : grid[model]) { - numTotal += computeNumCombinations(item); + int numCombinations = 0; + for (const auto& line : grid.at(model)) { + numCombinations += computeNumCombinations(line); } - auto result = std::vector(numTotal); - int base = 0; - for (const auto& item : grid[model]) { - int numCombinations = computeNumCombinations(item); - int line = 0; - for (const auto& hyperparam : item.items()) { - int numValues = hyperparam.value().size(); - for (const auto& value : hyperparam.value()) { - for (int i = 0; i < numCombinations / numValues; i++) { - result[base + line++][hyperparam.key()] = value; - //std::cout << "line=" << base + line << " " << hyperparam.key() << "=" << value << std::endl; - } - } - } - base += numCombinations; + return numCombinations; + } + json GridData::generateCombinations(json::iterator index, const json::iterator last, std::vector& output, json currentCombination) + { + if (index == last) { + // If we reached the end of input, store the current combination + output.push_back(currentCombination); + return currentCombination; } - for (const auto& item : result) { - std::cout << item.dump() << std::endl; + const auto& key = index.key(); + const auto& values = index.value(); + for (const auto& value : values) { + auto combination = currentCombination; + combination[key] = value; + json::iterator nextIndex = index; + generateCombinations(++nextIndex, last, output, combination); + } + return currentCombination; + } + std::vector GridData::getGrid(const std::string& model) + { + auto result = std::vector(); + for (json line : grid.at(model)) { + generateCombinations(line.begin(), line.end(), result, json({})); } return result; } diff --git a/src/Platform/GridData.h b/src/Platform/GridData.h index de60986..87ab74c 100644 --- a/src/Platform/GridData.h +++ b/src/Platform/GridData.h @@ -11,10 +11,11 @@ namespace platform { public: GridData(); ~GridData() = default; - std::vector getGrid(const std::string& model) { return doCombination(model); } + std::vector getGrid(const std::string& model); + int getNumCombinations(const std::string& model); private: + json generateCombinations(json::iterator index, const json::iterator last, std::vector& output, json currentCombination); int computeNumCombinations(const json& line); - std::vector doCombination(const std::string& model); std::map grid; }; } /* namespace platform */ diff --git a/src/Platform/GridSearch.cc b/src/Platform/GridSearch.cc index 3bd3f67..e5f072a 100644 --- a/src/Platform/GridSearch.cc +++ b/src/Platform/GridSearch.cc @@ -1,38 +1,91 @@ #include +#include #include "GridSearch.h" +#include "Models.h" #include "Paths.h" -#include "Datasets.h" -#include "HyperParameters.h" +#include "Folding.h" +#include "Colors.h" namespace platform { GridSearch::GridSearch(struct ConfigGrid& config) : config(config) { this->config.output_file = config.path + "grid_" + config.model + "_output.json"; } + void showProgress(int fold, const std::string& color, const std::string& phase) + { + std::string prefix = phase == "a" ? "" : "\b\b\b\b"; + std::cout << prefix << color << fold << Colors::RESET() << "(" << color << phase << Colors::RESET() << ")" << flush; + } + std::string getColor(bayesnet::status_t status) + { + switch (status) { + case bayesnet::NORMAL: + return Colors::GREEN(); + case bayesnet::WARNING: + return Colors::YELLOW(); + case bayesnet::ERROR: + return Colors::RED(); + default: + return Colors::RESET(); + } + } + void GridSearch::processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters) + { + // Get dataset + auto [X, y] = datasets.getTensors(fileName); + auto states = datasets.getStates(fileName); + auto features = datasets.getFeatures(fileName); + auto samples = datasets.getNSamples(fileName); + auto className = datasets.getClassName(fileName); + std::cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush; + for (const auto& seed : config.seeds) { + std::cout << "(" << seed << ") doing Fold: " << flush; + Fold* fold; + if (config.stratified) + fold = new StratifiedKFold(config.n_folds, y, seed); + else + fold = new KFold(config.n_folds, y.size(0), seed); + for (int nfold = 0; nfold < config.n_folds; nfold++) { + auto clf = Models::instance()->create(config.model); + auto [train, test] = fold->getFold(nfold); + // auto train_t = torch::tensor(train); + // auto test_t = torch::tensor(test); + // auto X_train = X.index({ "...", train_t }); + // auto y_train = y.index({ train_t }); + // auto X_test = X.index({ "...", test_t }); + // auto y_test = y.index({ test_t }); + showProgress(nfold + 1, getColor(clf->getStatus()), "a"); + // Train model + // clf->fit(X_train, y_train, features, className, states); + showProgress(nfold + 1, getColor(clf->getStatus()), "b"); + } + delete fold; + } + } void GridSearch::go() { // Load datasets - auto datasets = platform::Datasets(config.discretize, Paths::datasets()); - int i = 0; - for (const auto& item : grid.getGrid("BoostAODE")) { - std::cout << i++ << " hyperparams: " << item.dump() << std::endl; + auto datasets = Datasets(config.discretize, Paths::datasets()); + // Create model + std::cout << "***************** Starting Gridsearch *****************" << std::endl; + std::cout << "* Doing " << grid.getNumCombinations(config.model) << " combinations for each dataset/seed/fold" << std::endl; + // Generate hyperparameters grid & run gridsearch + // Check each combination of hyperparameters for each dataset and each seed + for (const auto& dataset : datasets.getNames()) { + std::cout << "- " << setw(20) << left << dataset << " " << right << flush; + for (const auto& hyperparam_line : grid.getGrid(config.model)) { + auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line); + processFile(dataset, datasets, hyperparameters); + } + std::cout << std::endl; } - // Load hyperparameters - // auto hyperparameters = platform::HyperParameters(datasets.getNames(), config.input_file); - // Check if hyperparameters are valid - // auto valid_hyperparameters = platform::Models::instance()->getHyperparameters(config.model); - // hyperparameters.check(valid_hyperparameters, config.model); - // // Load model - // auto model = platform::Models::instance()->get(config.model); - // // Run gridsearch - // auto grid = platform::Grid(datasets, hyperparameters, model, config.score, config.discretize, config.stratified, config.n_folds, config.seeds); - // grid.run(); - // // Save results - // grid.save(config.output_file); + // Save results + save(); } void GridSearch::save() { - + std::ofstream file(config.output_file); + // file << results.dump(4); + file.close(); } - } /* namespace platform */ \ No newline at end of file diff --git a/src/Platform/GridSearch.h b/src/Platform/GridSearch.h index 1db303c..220eccc 100644 --- a/src/Platform/GridSearch.h +++ b/src/Platform/GridSearch.h @@ -2,6 +2,8 @@ #define GRIDSEARCH_H #include #include +#include "Datasets.h" +#include "HyperParameters.h" #include "GridData.h" namespace platform { @@ -23,6 +25,7 @@ namespace platform { void save(); ~GridSearch() = default; private: + void processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters); struct ConfigGrid config; GridData grid; }; diff --git a/src/Platform/combinations.cc b/src/Platform/combinations.cc deleted file mode 100644 index fbaca83..0000000 --- a/src/Platform/combinations.cc +++ /dev/null @@ -1,57 +0,0 @@ -#include -#include -#include - -using json = nlohmann::json; - -json generateCombinations(json::iterator index, const json::iterator last, std::vector& output, json currentCombination) -{ - if (index == last) { - // If we reached the end of input, store the current combination - output.push_back(currentCombination); - return currentCombination; - } - const auto& key = index.key(); - const auto& values = index.value(); - for (const auto& value : values) { - auto combination = currentCombination; - combination[key] = value; - json::iterator nextIndex = index; - generateCombinations(++nextIndex, last, output, combination); - } - return currentCombination; -} - -int main() -{ - json input = R"( - [ - { - "convergence": [true, false], - "ascending": [true, false], - "repeatSparent": [true, false], - "select_features": ["CFS", "FCBF"], - "tolerance": [0, 3, 5], - "threshold": [1e-7] - }, - { - "convergence": [true, false], - "ascending": [true, false], - "repeatSparent": [true, false], - "select_features": ["IWSS"], - "tolerance": [0, 3, 5], - "threshold": [0.5] - } - ] - )"_json; - auto output = std::vector(); - for (json line : input) { - generateCombinations(line.begin(), line.end(), output, json({})); - } - // Print the generated combinations - int i = 0; - for (const auto& item : output) { - std::cout << i++ << " " << item.dump() << std::endl; - } - return 0; -}