Add discretization algorithm management to b_main & Dataset

This commit is contained in:
2024-06-07 09:00:51 +02:00
parent 2202a81782
commit 5dd3deca1a
5 changed files with 28 additions and 27 deletions

View File

@@ -111,7 +111,7 @@ int main(int argc, char** argv)
cerr << program; cerr << program;
exit(1); exit(1);
} }
auto datasets = platform::Datasets(discretize_dataset, platform::Paths::datasets()); auto datasets = platform::Datasets(false, platform::Paths::datasets());
if (datasets_file != "") { if (datasets_file != "") {
ifstream catalog(datasets_file); ifstream catalog(datasets_file);
if (catalog.is_open()) { if (catalog.is_open()) {

View File

@@ -2,6 +2,7 @@
#include <fstream> #include <fstream>
#include "Dataset.h" #include "Dataset.h"
namespace platform { namespace platform {
const std::string message_dataset_not_loaded = "Dataset not loaded.";
Dataset::Dataset(const Dataset& dataset) : Dataset::Dataset(const Dataset& dataset) :
path(dataset.path), name(dataset.name), className(dataset.className), n_samples(dataset.n_samples), path(dataset.path), name(dataset.name), className(dataset.className), n_samples(dataset.n_samples),
n_features(dataset.n_features), numericFeatures(dataset.numericFeatures), features(dataset.features), n_features(dataset.n_features), numericFeatures(dataset.numericFeatures), features(dataset.features),
@@ -23,7 +24,7 @@ namespace platform {
if (loaded) { if (loaded) {
return features; return features;
} else { } else {
throw std::invalid_argument("Dataset not loaded."); throw std::invalid_argument(message_dataset_not_loaded);
} }
} }
int Dataset::getNFeatures() const int Dataset::getNFeatures() const
@@ -31,7 +32,7 @@ namespace platform {
if (loaded) { if (loaded) {
return n_features; return n_features;
} else { } else {
throw std::invalid_argument("Dataset not loaded."); throw std::invalid_argument(message_dataset_not_loaded);
} }
} }
int Dataset::getNSamples() const int Dataset::getNSamples() const
@@ -39,7 +40,7 @@ namespace platform {
if (loaded) { if (loaded) {
return n_samples; return n_samples;
} else { } else {
throw std::invalid_argument("Dataset not loaded."); throw std::invalid_argument(message_dataset_not_loaded);
} }
} }
std::map<std::string, std::vector<int>> Dataset::getStates() const std::map<std::string, std::vector<int>> Dataset::getStates() const
@@ -47,7 +48,7 @@ namespace platform {
if (loaded) { if (loaded) {
return states; return states;
} else { } else {
throw std::invalid_argument("Dataset not loaded."); throw std::invalid_argument(message_dataset_not_loaded);
} }
} }
pair<std::vector<std::vector<float>>&, std::vector<int>&> Dataset::getVectors() pair<std::vector<std::vector<float>>&, std::vector<int>&> Dataset::getVectors()
@@ -55,7 +56,7 @@ namespace platform {
if (loaded) { if (loaded) {
return { Xv, yv }; return { Xv, yv };
} else { } else {
throw std::invalid_argument("Dataset not loaded."); throw std::invalid_argument(message_dataset_not_loaded);
} }
} }
pair<std::vector<std::vector<int>>&, std::vector<int>&> Dataset::getVectorsDiscretized() pair<std::vector<std::vector<int>>&, std::vector<int>&> Dataset::getVectorsDiscretized()
@@ -63,7 +64,7 @@ namespace platform {
if (loaded) { if (loaded) {
return { Xd, yv }; return { Xd, yv };
} else { } else {
throw std::invalid_argument("Dataset not loaded."); throw std::invalid_argument(message_dataset_not_loaded);
} }
} }
pair<torch::Tensor&, torch::Tensor&> Dataset::getTensors() pair<torch::Tensor&, torch::Tensor&> Dataset::getTensors()
@@ -72,7 +73,7 @@ namespace platform {
buildTensors(); buildTensors();
return { X, y }; return { X, y };
} else { } else {
throw std::invalid_argument("Dataset not loaded."); throw std::invalid_argument(message_dataset_not_loaded);
} }
} }
void Dataset::load_csv() void Dataset::load_csv()

View File

@@ -4,7 +4,13 @@
namespace platform { namespace platform {
class Datasets { class Datasets {
public: public:
explicit Datasets(bool discretize, std::string sfileType) : discretize(discretize), sfileType(sfileType) { load(); }; explicit Datasets(bool discretize, std::string sfileType, std::string discretizer_algo = "none") : discretize(discretize), sfileType(sfileType), discretizer_algo(discretizer_algo)
{
if (discretizer_algo == "none" && discretize) {
throw std::runtime_error("Can't discretize without discretization algorithm");
}
load();
};
std::vector<std::string> getNames(); std::vector<std::string> getNames();
std::vector<std::string> getFeatures(const std::string& name) const; std::vector<std::string> getFeatures(const std::string& name) const;
int getNSamples(const std::string& name) const; int getNSamples(const std::string& name) const;
@@ -17,6 +23,7 @@ namespace platform {
std::pair<std::vector<std::vector<float>>&, std::vector<int>&> getVectors(const std::string& name); std::pair<std::vector<std::vector<float>>&, std::vector<int>&> getVectors(const std::string& name);
std::pair<std::vector<std::vector<int>>&, std::vector<int>&> getVectorsDiscretized(const std::string& name); std::pair<std::vector<std::vector<int>>&, std::vector<int>&> getVectorsDiscretized(const std::string& name);
std::pair<torch::Tensor&, torch::Tensor&> getTensors(const std::string& name); std::pair<torch::Tensor&, torch::Tensor&> getTensors(const std::string& name);
std::tuple<torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&> getTrainTestTensors(const std::vector<int>& train_idx, const std::vector<int>& test_idx);
bool isDataset(const std::string& name) const; bool isDataset(const std::string& name) const;
void loadDataset(const std::string& name) const; void loadDataset(const std::string& name) const;
std::string toString() const; std::string toString() const;
@@ -24,6 +31,7 @@ namespace platform {
std::string path; std::string path;
fileType_t fileType; fileType_t fileType;
std::string sfileType; std::string sfileType;
std::string discretizer_algo;
std::map<std::string, std::unique_ptr<Dataset>> datasets; std::map<std::string, std::unique_ptr<Dataset>> datasets;
bool discretize; bool discretize;
void load(); // Loads the list of datasets void load(); // Loads the list of datasets

View File

@@ -373,7 +373,8 @@ namespace platform {
MPI_Bcast(msg, tasks_size + 1, MPI_CHAR, config_mpi.manager, MPI_COMM_WORLD); MPI_Bcast(msg, tasks_size + 1, MPI_CHAR, config_mpi.manager, MPI_COMM_WORLD);
tasks = json::parse(msg); tasks = json::parse(msg);
delete[] msg; delete[] msg;
auto datasets = Datasets(config.discretize, Paths::datasets()); auto env = platform::DotEnv();
auto datasets = Datasets(config.discretize, Paths::datasets(), env.get("discretiz_algo"));
if (config_mpi.rank == config_mpi.manager) { if (config_mpi.rank == config_mpi.manager) {
// //
// 2a. Producer delivers the tasks to the consumers // 2a. Producer delivers the tasks to the consumers

View File

@@ -117,20 +117,19 @@ namespace platform {
{ {
auto datasets = Datasets(false, Paths::datasets()); // Never discretize here auto datasets = Datasets(false, Paths::datasets()); // Never discretize here
// Get dataset // Get dataset
auto [X, y] = datasets.getTensors(fileName); // -------------- auto [X, y] = datasets.getTensors(fileName);
auto states = datasets.getStates(fileName); // -------------- auto states = datasets.getStates(fileName);
auto features = datasets.getFeatures(fileName); auto features = datasets.getFeatures(fileName);
auto samples = datasets.getNSamples(fileName); auto samples = datasets.getNSamples(fileName);
auto className = datasets.getClassName(fileName); auto className = datasets.getClassName(fileName);
auto labels = datasets.getLabels(fileName); auto labels = datasets.getLabels(fileName);
int num_classes = states[className].size() == 0 ? labels.size() : states[className].size(); int num_classes = labels.size();
if (!quiet) { if (!quiet) {
std::cout << " " << setw(5) << samples << " " << setw(5) << features.size() << flush; std::cout << " " << setw(5) << samples << " " << setw(5) << features.size() << flush;
} }
// Prepare Result // Prepare Result
auto partial_result = PartialResult(); auto partial_result = PartialResult();
auto [values, counts] = at::_unique(y); partial_result.setSamples(samples).setFeatures(features.size()).setClasses(num_classes);
partial_result.setSamples(X.size(1)).setFeatures(X.size(0)).setClasses(values.size(0));
partial_result.setHyperparameters(hyperparameters.get(fileName)); partial_result.setHyperparameters(hyperparameters.get(fileName));
// Initialize results std::vectors // Initialize results std::vectors
int nResults = nfolds * static_cast<int>(randomSeeds.size()); int nResults = nfolds * static_cast<int>(randomSeeds.size());
@@ -170,18 +169,10 @@ namespace platform {
// Split train - test dataset // Split train - test dataset
train_timer.start(); train_timer.start();
auto [train, test] = fold->getFold(nfold); auto [train, test] = fold->getFold(nfold);
auto train_t = torch::tensor(train); auto [X_train, X_test, y_train, y_test] = datasets.getTrainTestTensors(fileName, train, test);
auto test_t = torch::tensor(test); // Possible refactor: remove all the per-dataset accessor methods from Datasets and keep a single getDataset that returns
auto X_train = X.index({ "...", train_t }); // a reference to the Dataset object, then work directly with it.
auto y_train = y.index({ train_t }); auto states = datasets.getStates(fileName);
auto X_test = X.index({ "...", test_t });
auto y_test = y.index({ test_t });
if (discretized) {
// compute states too
// discretizer->fit(X_train, y_train);
// X_train = discretizer->transform(X_train);
// X_test = discretizer->transform(X_test);
}
if (generate_fold_files) if (generate_fold_files)
generate_files(fileName, discretized, stratified, seed, nfold, X_train, y_train, X_test, y_test, train, test); generate_files(fileName, discretized, stratified, seed, nfold, X_train, y_train, X_test, y_test, train, test);
if (!quiet) if (!quiet)