Add discretiz algorithm management to b_main & Dataset

This commit is contained in:
2024-06-07 09:00:51 +02:00
parent 2202a81782
commit 5dd3deca1a
5 changed files with 28 additions and 27 deletions

View File

@@ -2,6 +2,7 @@
#include <fstream>
#include "Dataset.h"
namespace platform {
const std::string message_dataset_not_loaded = "Dataset not loaded.";
Dataset::Dataset(const Dataset& dataset) :
path(dataset.path), name(dataset.name), className(dataset.className), n_samples(dataset.n_samples),
n_features(dataset.n_features), numericFeatures(dataset.numericFeatures), features(dataset.features),
@@ -23,7 +24,7 @@ namespace platform {
if (loaded) {
return features;
} else {
throw std::invalid_argument("Dataset not loaded.");
throw std::invalid_argument(message_dataset_not_loaded);
}
}
int Dataset::getNFeatures() const
@@ -31,7 +32,7 @@ namespace platform {
if (loaded) {
return n_features;
} else {
throw std::invalid_argument("Dataset not loaded.");
throw std::invalid_argument(message_dataset_not_loaded);
}
}
int Dataset::getNSamples() const
@@ -39,7 +40,7 @@ namespace platform {
if (loaded) {
return n_samples;
} else {
throw std::invalid_argument("Dataset not loaded.");
throw std::invalid_argument(message_dataset_not_loaded);
}
}
std::map<std::string, std::vector<int>> Dataset::getStates() const
@@ -47,7 +48,7 @@ namespace platform {
if (loaded) {
return states;
} else {
throw std::invalid_argument("Dataset not loaded.");
throw std::invalid_argument(message_dataset_not_loaded);
}
}
pair<std::vector<std::vector<float>>&, std::vector<int>&> Dataset::getVectors()
@@ -55,7 +56,7 @@ namespace platform {
if (loaded) {
return { Xv, yv };
} else {
throw std::invalid_argument("Dataset not loaded.");
throw std::invalid_argument(message_dataset_not_loaded);
}
}
pair<std::vector<std::vector<int>>&, std::vector<int>&> Dataset::getVectorsDiscretized()
@@ -63,7 +64,7 @@ namespace platform {
if (loaded) {
return { Xd, yv };
} else {
throw std::invalid_argument("Dataset not loaded.");
throw std::invalid_argument(message_dataset_not_loaded);
}
}
pair<torch::Tensor&, torch::Tensor&> Dataset::getTensors()
@@ -72,7 +73,7 @@ namespace platform {
buildTensors();
return { X, y };
} else {
throw std::invalid_argument("Dataset not loaded.");
throw std::invalid_argument(message_dataset_not_loaded);
}
}
void Dataset::load_csv()

View File

@@ -4,7 +4,13 @@
namespace platform {
class Datasets {
public:
explicit Datasets(bool discretize, std::string sfileType) : discretize(discretize), sfileType(sfileType) { load(); };
explicit Datasets(bool discretize, std::string sfileType, std::string discretizer_algo = "none") : discretize(discretize), sfileType(sfileType), discretizer_algo(discretizer_algo)
{
if (discretizer_algo == "none" && discretize) {
throw std::runtime_error("Can't discretize without discretization algorithm");
}
load();
};
std::vector<std::string> getNames();
std::vector<std::string> getFeatures(const std::string& name) const;
int getNSamples(const std::string& name) const;
@@ -17,6 +23,7 @@ namespace platform {
std::pair<std::vector<std::vector<float>>&, std::vector<int>&> getVectors(const std::string& name);
std::pair<std::vector<std::vector<int>>&, std::vector<int>&> getVectorsDiscretized(const std::string& name);
std::pair<torch::Tensor&, torch::Tensor&> getTensors(const std::string& name);
std::tuple<torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&> getTrainTestTensors(const std::vector<int>& train_idx, const std::vector<int>& test_idx);
bool isDataset(const std::string& name) const;
void loadDataset(const std::string& name) const;
std::string toString() const;
@@ -24,6 +31,7 @@ namespace platform {
std::string path;
fileType_t fileType;
std::string sfileType;
std::string discretizer_algo;
std::map<std::string, std::unique_ptr<Dataset>> datasets;
bool discretize;
void load(); // Loads the list of datasets