Add numeric features management to Dataset

This commit is contained in:
2024-06-06 13:03:57 +02:00
parent 6858b3d89a
commit a7ec930fa0
15 changed files with 210 additions and 43 deletions

View File

@@ -2,7 +2,12 @@
#include <fstream>
#include "Dataset.h"
namespace platform {
Dataset::Dataset(const Dataset& dataset) : path(dataset.path), name(dataset.name), className(dataset.className), n_samples(dataset.n_samples), n_features(dataset.n_features), features(dataset.features), states(dataset.states), loaded(dataset.loaded), discretize(dataset.discretize), X(dataset.X), y(dataset.y), Xv(dataset.Xv), Xd(dataset.Xd), yv(dataset.yv), fileType(dataset.fileType)
Dataset::Dataset(const Dataset& dataset) :
path(dataset.path), name(dataset.name), className(dataset.className), n_samples(dataset.n_samples),
n_features(dataset.n_features), numericFeatures(dataset.numericFeatures), features(dataset.features),
states(dataset.states), loaded(dataset.loaded), discretize(dataset.discretize), X(dataset.X), y(dataset.y),
X_train(dataset.X_train), X_test(dataset.X_test), Xv(dataset.Xv), Xd(dataset.Xd), yv(dataset.yv),
fileType(dataset.fileType)
{
}
std::string Dataset::getName() const
@@ -180,12 +185,20 @@ namespace platform {
} else if (fileType == RDATA) {
load_rdata();
}
n_samples = Xv[0].size();
n_features = Xv.size();
if (numericFeaturesIdx.at(0) == -1) {
numericFeatures = std::vector<bool>(n_features, true);
} else {
numericFeatures = std::vector<bool>(n_features, false);
for (auto i : numericFeaturesIdx) {
numericFeatures[i] = true;
}
}
if (discretize) {
Xd = discretizeDataset(Xv, yv);
computeStates();
}
n_samples = Xv[0].size();
n_features = Xv.size();
loaded = true;
}
void Dataset::buildTensors()
@@ -215,4 +228,9 @@ namespace platform {
}
return Xd;
}
std::pair <torch::Tensor&, torch::Tensor&> Dataset::getDiscretizedTrainTestTensors()
{
auto discretizer = Discretization::instance()->create("mdlp");
return { X_train, X_test };
}
}

View File

@@ -4,14 +4,17 @@
#include <map>
#include <vector>
#include <string>
#include <CPPFImdlp.h>
#include <common/DiscretizationRegister.h>
#include "Utils.h"
#include "SourceData.h"
namespace platform {
class Dataset {
public:
Dataset(const std::string& path, const std::string& name, const std::string& className, bool discretize, fileType_t fileType) : path(path), name(name), className(className), discretize(discretize), loaded(false), fileType(fileType) {};
Dataset(const std::string& path, const std::string& name, const std::string& className, bool discretize, fileType_t fileType, std::vector<int> numericFeaturesIdx) :
path(path), name(name), className(className), discretize(discretize),
loaded(false), fileType(fileType), numericFeaturesIdx(numericFeaturesIdx)
{
};
explicit Dataset(const Dataset&);
std::string getName() const;
std::string getClassName() const;
@@ -20,9 +23,11 @@ namespace platform {
std::map<std::string, std::vector<int>> getStates() const;
std::pair<vector<std::vector<float>>&, std::vector<int>&> getVectors();
std::pair<vector<std::vector<int>>&, std::vector<int>&> getVectorsDiscretized();
std::pair<torch::Tensor&, torch::Tensor&> getDiscretizedTrainTestTensors();
std::pair<torch::Tensor&, torch::Tensor&> getTensors();
int getNFeatures() const;
int getNSamples() const;
std::vector<bool>& getNumericFeatures() { return numericFeatures; }
void load();
const bool inline isLoaded() const { return loaded; };
private:
@@ -31,12 +36,15 @@ namespace platform {
fileType_t fileType;
std::string className;
int n_samples{ 0 }, n_features{ 0 };
std::vector<int> numericFeaturesIdx;
std::vector<bool> numericFeatures; // true if feature is numeric
std::vector<std::string> features;
std::vector<std::string> labels;
std::map<std::string, std::vector<int>> states;
bool loaded;
bool discretize;
torch::Tensor X, y;
torch::Tensor X_train, X_test;
std::vector<std::vector<float>> Xv;
std::vector<std::vector<int>> Xd;
std::vector<int> yv;

View File

@@ -1,27 +1,47 @@
#include <fstream>
#include "Datasets.h"
#include <nlohmann/json.hpp>
namespace platform {
using json = nlohmann::ordered_json;
const std::string message_dataset_not_loaded = "dataset not loaded.";
void Datasets::load()
{
auto sd = SourceData(sfileType);
fileType = sd.getFileType();
path = sd.getPath();
ifstream catalog(path + "all.txt");
std::vector<int> numericFeaturesIdx;
if (catalog.is_open()) {
std::string line;
while (getline(catalog, line)) {
if (line.empty() || line[0] == '#') {
continue;
}
std::vector<std::string> tokens = split(line, ',');
std::vector<std::string> tokens = split(line, ';');
std::string name = tokens[0];
std::string className;
numericFeaturesIdx.clear();
if (tokens.size() == 1) {
className = "-1";
numericFeaturesIdx.push_back(-1);
} else {
className = tokens[1];
if (tokens.size() > 2) {
auto numericFeatures = tokens[2];
if (numericFeatures == "all") {
numericFeaturesIdx.push_back(-1);
} else {
auto features = json::parse(numericFeatures);
for (auto& f : features) {
numericFeaturesIdx.push_back(f);
}
}
} else {
numericFeaturesIdx.push_back(-1);
}
}
datasets[name] = make_unique<Dataset>(path, name, className, discretize, fileType);
datasets[name] = make_unique<Dataset>(path, name, className, discretize, fileType, numericFeaturesIdx);
}
catalog.close();
} else {
@@ -39,7 +59,7 @@ namespace platform {
if (datasets.at(name)->isLoaded()) {
return datasets.at(name)->getFeatures();
} else {
throw std::invalid_argument("Dataset not loaded.");
throw std::invalid_argument(message_dataset_not_loaded);
}
}
std::vector<std::string> Datasets::getLabels(const std::string& name) const
@@ -47,7 +67,7 @@ namespace platform {
if (datasets.at(name)->isLoaded()) {
return datasets.at(name)->getLabels();
} else {
throw std::invalid_argument("Dataset not loaded.");
throw std::invalid_argument(message_dataset_not_loaded);
}
}
map<std::string, std::vector<int>> Datasets::getStates(const std::string& name) const
@@ -55,7 +75,7 @@ namespace platform {
if (datasets.at(name)->isLoaded()) {
return datasets.at(name)->getStates();
} else {
throw std::invalid_argument("Dataset not loaded.");
throw std::invalid_argument(message_dataset_not_loaded);
}
}
void Datasets::loadDataset(const std::string& name) const
@@ -71,7 +91,7 @@ namespace platform {
if (datasets.at(name)->isLoaded()) {
return datasets.at(name)->getClassName();
} else {
throw std::invalid_argument("Dataset not loaded.");
throw std::invalid_argument(message_dataset_not_loaded);
}
}
int Datasets::getNSamples(const std::string& name) const
@@ -79,7 +99,7 @@ namespace platform {
if (datasets.at(name)->isLoaded()) {
return datasets.at(name)->getNSamples();
} else {
throw std::invalid_argument("Dataset not loaded.");
throw std::invalid_argument(message_dataset_not_loaded);
}
}
int Datasets::getNClasses(const std::string& name)
@@ -93,7 +113,15 @@ namespace platform {
auto [Xv, yv] = getVectors(name);
return *std::max_element(yv.begin(), yv.end()) + 1;
} else {
throw std::invalid_argument("Dataset not loaded.");
throw std::invalid_argument(message_dataset_not_loaded);
}
}
std::vector<bool>& Datasets::getNumericFeatures(const std::string& name) const
{
if (datasets.at(name)->isLoaded()) {
return datasets.at(name)->getNumericFeatures();
} else {
throw std::invalid_argument(message_dataset_not_loaded);
}
}
std::vector<int> Datasets::getClassesCounts(const std::string& name) const
@@ -106,7 +134,7 @@ namespace platform {
}
return counts;
} else {
throw std::invalid_argument("Dataset not loaded.");
throw std::invalid_argument(message_dataset_not_loaded);
}
}
pair<std::vector<std::vector<float>>&, std::vector<int>&> Datasets::getVectors(const std::string& name)

View File

@@ -11,6 +11,7 @@ namespace platform {
std::vector<std::string> getLabels(const std::string& name) const;
std::string getClassName(const std::string& name) const;
int getNClasses(const std::string& name);
std::vector<bool>& getNumericFeatures(const std::string& name) const;
std::vector<int> getClassesCounts(const std::string& name) const;
std::map<std::string, std::vector<int>> getStates(const std::string& name) const;
std::pair<std::vector<std::vector<float>>&, std::vector<int>&> getVectors(const std::string& name);

View File

@@ -0,0 +1,55 @@
#include "Discretization.h"
namespace platform {
// Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory
Discretization* Discretization::factory = nullptr;
Discretization* Discretization::instance()
{
//manages singleton
if (factory == nullptr)
factory = new Discretization();
return factory;
}
void Discretization::registerFactoryFunction(const std::string& name,
function<mdlp::Discretizer* (void)> classFactoryFunction)
{
// register the class factory function
functionRegistry[name] = classFactoryFunction;
}
std::shared_ptr<mdlp::Discretizer> Discretization::create(const std::string& name)
{
mdlp::Discretizer* instance = nullptr;
// find name in the registry and call factory method.
auto it = functionRegistry.find(name);
if (it != functionRegistry.end())
instance = it->second();
// wrap instance in a shared ptr and return
if (instance != nullptr)
return std::unique_ptr<mdlp::Discretizer>(instance);
else
throw std::runtime_error("Discretizer not found: " + name);
}
std::vector<std::string> Discretization::getNames()
{
std::vector<std::string> names;
transform(functionRegistry.begin(), functionRegistry.end(), back_inserter(names),
[](const pair<std::string, function<mdlp::Discretizer* (void)>>& pair) { return pair.first; });
return names;
}
std::string Discretization::toString()
{
std::string result = "";
std::string sep = "";
for (const auto& pair : functionRegistry) {
result += sep + pair.first;
sep = ", ";
}
return "{" + result + "}";
}
RegistrarDiscretization::RegistrarDiscretization(const std::string& name, function<mdlp::Discretizer* (void)> classFactoryFunction)
{
// register the class factory function
Discretization::instance()->registerFactoryFunction(name, classFactoryFunction);
}
}

View File

@@ -0,0 +1,33 @@
#ifndef DISCRETIZATION_H
#define DISCRETIZATION_H
#include <map>
#include <memory>
#include <string>
#include <functional>
#include <vector>
#include <Discretizer.h>
#include <BinDisc.h>
#include <CPPFImdlp.h>
namespace platform {
class Discretization {
public:
Discretization(Discretization&) = delete;
void operator=(const Discretization&) = delete;
// Idea from: https://www.codeproject.com/Articles/567242/AplusC-2b-2bplusObjectplusFactory
static Discretization* instance();
std::shared_ptr<mdlp::Discretizer> create(const std::string& name);
void registerFactoryFunction(const std::string& name,
function<mdlp::Discretizer* (void)> classFactoryFunction);
std::vector<string> getNames();
std::string toString();
private:
map<std::string, function<mdlp::Discretizer* (void)>> functionRegistry;
static Discretization* factory; //singleton
Discretization() {};
};
class RegistrarDiscretization {
public:
RegistrarDiscretization(const std::string& className, function<mdlp::Discretizer* (void)> classFactoryFunction);
};
}
#endif

View File

@@ -0,0 +1,10 @@
#ifndef DISCRETIZATIONREGISTER_H
#define DISCRETIZATIONREGISTER_H
#include <common/Discretization.h>
static platform::RegistrarDiscretization registrarM("mdlp",
[](void) -> mdlp::Discretizer* { return new mdlp::CPPFImdlp();});
static platform::RegistrarDiscretization registrarBU("BinUniform",
[](void) -> mdlp::Discretizer* { return new mdlp::BinDisc(3, mdlp::strategy_t::UNIFORM);});
static platform::RegistrarDiscretization registrarBQ("BinQuantile",
[](void) -> mdlp::Discretizer* { return new mdlp::BinDisc(3, mdlp::strategy_t::QUANTILE);});
#endif

View File

@@ -3,17 +3,8 @@
#include <sstream>
#include <string>
#include <vector>
#include <algorithm>
namespace platform {
static std::vector<std::string> split(const std::string& text, char delimiter)
{
std::vector<std::string> result;
std::stringstream ss(text);
std::string token;
while (std::getline(ss, token, delimiter)) {
result.push_back(token);
}
return result;
}
static std::string trim(const std::string& str)
{
std::string result = str;
@@ -25,5 +16,15 @@ namespace platform {
}).base(), result.end());
return result;
}
static std::vector<std::string> split(const std::string& text, char delimiter)
{
std::vector<std::string> result;
std::stringstream ss(text);
std::string token;
while (std::getline(ss, token, delimiter)) {
result.push_back(trim(token));
}
return result;
}
}
#endif