From 5dd3deca1ac9307f79a18655f05d85ed21d31200 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Fri, 7 Jun 2024 09:00:51 +0200 Subject: [PATCH] Add discretiz algorithm management to b_main & Dataset --- src/commands/b_main.cpp | 2 +- src/common/Dataset.cpp | 15 ++++++++------- src/common/Datasets.h | 10 +++++++++- src/grid/GridSearch.cpp | 3 ++- src/main/Experiment.cpp | 25 ++++++++----------------- 5 files changed, 28 insertions(+), 27 deletions(-) diff --git a/src/commands/b_main.cpp b/src/commands/b_main.cpp index ba5a945..89021ea 100644 --- a/src/commands/b_main.cpp +++ b/src/commands/b_main.cpp @@ -111,7 +111,7 @@ int main(int argc, char** argv) cerr << program; exit(1); } - auto datasets = platform::Datasets(discretize_dataset, platform::Paths::datasets()); + auto datasets = platform::Datasets(false, platform::Paths::datasets()); if (datasets_file != "") { ifstream catalog(datasets_file); if (catalog.is_open()) { diff --git a/src/common/Dataset.cpp b/src/common/Dataset.cpp index 0af0bfa..194e23d 100644 --- a/src/common/Dataset.cpp +++ b/src/common/Dataset.cpp @@ -2,6 +2,7 @@ #include #include "Dataset.h" namespace platform { + const std::string message_dataset_not_loaded = "Dataset not loaded."; Dataset::Dataset(const Dataset& dataset) : path(dataset.path), name(dataset.name), className(dataset.className), n_samples(dataset.n_samples), n_features(dataset.n_features), numericFeatures(dataset.numericFeatures), features(dataset.features), @@ -23,7 +24,7 @@ namespace platform { if (loaded) { return features; } else { - throw std::invalid_argument("Dataset not loaded."); + throw std::invalid_argument(message_dataset_not_loaded); } } int Dataset::getNFeatures() const @@ -31,7 +32,7 @@ namespace platform { if (loaded) { return n_features; } else { - throw std::invalid_argument("Dataset not loaded."); + throw std::invalid_argument(message_dataset_not_loaded); } } int Dataset::getNSamples() const @@ -39,7 +40,7 @@ namespace platform { if 
(loaded) { return n_samples; } else { - throw std::invalid_argument("Dataset not loaded."); + throw std::invalid_argument(message_dataset_not_loaded); } } std::map> Dataset::getStates() const @@ -47,7 +48,7 @@ namespace platform { if (loaded) { return states; } else { - throw std::invalid_argument("Dataset not loaded."); + throw std::invalid_argument(message_dataset_not_loaded); } } pair>&, std::vector&> Dataset::getVectors() @@ -55,7 +56,7 @@ namespace platform { if (loaded) { return { Xv, yv }; } else { - throw std::invalid_argument("Dataset not loaded."); + throw std::invalid_argument(message_dataset_not_loaded); } } pair>&, std::vector&> Dataset::getVectorsDiscretized() @@ -63,7 +64,7 @@ namespace platform { if (loaded) { return { Xd, yv }; } else { - throw std::invalid_argument("Dataset not loaded."); + throw std::invalid_argument(message_dataset_not_loaded); } } pair Dataset::getTensors() @@ -72,7 +73,7 @@ namespace platform { buildTensors(); return { X, y }; } else { - throw std::invalid_argument("Dataset not loaded."); + throw std::invalid_argument(message_dataset_not_loaded); } } void Dataset::load_csv() diff --git a/src/common/Datasets.h b/src/common/Datasets.h index df63c02..44c8a05 100644 --- a/src/common/Datasets.h +++ b/src/common/Datasets.h @@ -4,7 +4,13 @@ namespace platform { class Datasets { public: - explicit Datasets(bool discretize, std::string sfileType) : discretize(discretize), sfileType(sfileType) { load(); }; + explicit Datasets(bool discretize, std::string sfileType, std::string discretizer_algo = "none") : discretize(discretize), sfileType(sfileType), discretizer_algo(discretizer_algo) + { + if (discretizer_algo == "none" && discretize) { + throw std::runtime_error("Can't discretize without discretization algorithm"); + } + load(); + }; std::vector getNames(); std::vector getFeatures(const std::string& name) const; int getNSamples(const std::string& name) const; @@ -17,6 +23,7 @@ namespace platform { std::pair>&, std::vector&> 
getVectors(const std::string& name); std::pair>&, std::vector&> getVectorsDiscretized(const std::string& name); std::pair getTensors(const std::string& name); + std::tuple getTrainTestTensors(const std::vector& train_idx, const std::vector& test_idx); bool isDataset(const std::string& name) const; void loadDataset(const std::string& name) const; std::string toString() const; @@ -24,6 +31,7 @@ namespace platform { std::string path; fileType_t fileType; std::string sfileType; + std::string discretizer_algo; std::map> datasets; bool discretize; void load(); // Loads the list of datasets diff --git a/src/grid/GridSearch.cpp b/src/grid/GridSearch.cpp index ef8b8da..5ea92a8 100644 --- a/src/grid/GridSearch.cpp +++ b/src/grid/GridSearch.cpp @@ -373,7 +373,8 @@ namespace platform { MPI_Bcast(msg, tasks_size + 1, MPI_CHAR, config_mpi.manager, MPI_COMM_WORLD); tasks = json::parse(msg); delete[] msg; - auto datasets = Datasets(config.discretize, Paths::datasets()); + auto env = platform::DotEnv(); + auto datasets = Datasets(config.discretize, Paths::datasets(), env.get("discretiz_algo")); if (config_mpi.rank == config_mpi.manager) { // // 2a. Producer delivers the tasks to the consumers diff --git a/src/main/Experiment.cpp b/src/main/Experiment.cpp index cd6e825..c5fb739 100644 --- a/src/main/Experiment.cpp +++ b/src/main/Experiment.cpp @@ -117,20 +117,19 @@ namespace platform { { auto datasets = Datasets(false, Paths::datasets()); // Never discretize here // Get dataset - auto [X, y] = datasets.getTensors(fileName); - auto states = datasets.getStates(fileName); + // -------------- auto [X, y] = datasets.getTensors(fileName); + // -------------- auto states = datasets.getStates(fileName); auto features = datasets.getFeatures(fileName); auto samples = datasets.getNSamples(fileName); auto className = datasets.getClassName(fileName); auto labels = datasets.getLabels(fileName); - int num_classes = states[className].size() == 0 ? 
labels.size() : states[className].size(); + int num_classes = labels.size(); if (!quiet) { std::cout << " " << setw(5) << samples << " " << setw(5) << features.size() << flush; } // Prepare Result auto partial_result = PartialResult(); - auto [values, counts] = at::_unique(y); - partial_result.setSamples(X.size(1)).setFeatures(X.size(0)).setClasses(values.size(0)); + partial_result.setSamples(samples).setFeatures(features.size()).setClasses(num_classes); partial_result.setHyperparameters(hyperparameters.get(fileName)); // Initialize results std::vectors int nResults = nfolds * static_cast(randomSeeds.size()); @@ -170,18 +169,10 @@ namespace platform { // Split train - test dataset train_timer.start(); auto [train, test] = fold->getFold(nfold); - auto train_t = torch::tensor(train); - auto test_t = torch::tensor(test); - auto X_train = X.index({ "...", train_t }); - auto y_train = y.index({ train_t }); - auto X_test = X.index({ "...", test_t }); - auto y_test = y.index({ test_t }); - if (discretized) { - // compute states too - // discretizer->fit(X_train, y_train); - // X_train = discretizer->transform(X_train); - // X_test = discretizer->transform(X_test); - } + auto [X_train, X_test, y_train, y_test] = datasets.getTrainTestTensors(fileName, train, test); + // Possible refactor: remove all the accessor methods from Datasets and keep a single getDataset that returns + // a reference to the Dataset object, then work with it directly. + auto states = datasets.getStates(fileName); if (generate_fold_files) generate_files(fileName, discretized, stratified, seed, nfold, X_train, y_train, X_test, y_test, train, test); if (!quiet)