Adding Datasets management
This commit is contained in:
63
src/Platform/Datasets.h
Normal file
63
src/Platform/Datasets.h
Normal file
@@ -0,0 +1,63 @@
|
||||
#ifndef DATASETS_H
|
||||
#define DATASETS_H
|
||||
#include <torch/torch.h>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
namespace platform {
|
||||
using namespace std;
|
||||
enum fileType_t { CSV, ARFF };
|
||||
class Dataset {
|
||||
private:
|
||||
string path;
|
||||
string name;
|
||||
fileType_t fileType;
|
||||
string className;
|
||||
int n_samples, n_features;
|
||||
vector<string> features;
|
||||
map<string, vector<int>> states;
|
||||
bool loaded;
|
||||
bool discretize;
|
||||
torch::Tensor X, y;
|
||||
vector<vector<float>> Xv;
|
||||
vector<vector<int>> Xd;
|
||||
vector<int> yv;
|
||||
void buildTensors();
|
||||
void load_csv();
|
||||
void load_arff();
|
||||
void computeStates();
|
||||
public:
|
||||
Dataset(string path, string name, string className, bool discretize, fileType_t fileType) : path(path), name(name), className(className), discretize(discretize), loaded(false), fileType(fileType) {};
|
||||
Dataset(Dataset&);
|
||||
string getName();
|
||||
string getClassName();
|
||||
vector<string> getFeatures();
|
||||
map<string, vector<int>> getStates();
|
||||
pair<vector<vector<float>>&, vector<int>&> getVectors();
|
||||
pair<vector<vector<int>>&, vector<int>&> getVectorsDiscretized();
|
||||
pair<torch::Tensor&, torch::Tensor&> getTensors();
|
||||
int getNFeatures();
|
||||
int getNSamples();
|
||||
void load();
|
||||
const bool inline isLoaded() const { return loaded; };
|
||||
};
|
||||
class Datasets {
|
||||
private:
|
||||
string path;
|
||||
fileType_t fileType;
|
||||
map<string, unique_ptr<Dataset>> datasets;
|
||||
bool discretize;
|
||||
void load(); // Loads the list of datasets
|
||||
public:
|
||||
Datasets(string path, bool discretize = false, fileType_t fileType = ARFF) : path(path), discretize(discretize), fileType(fileType) { load(); };
|
||||
Dataset& getDataset(string name);
|
||||
vector<string> getNames();
|
||||
vector<string> getFeatures(string name);
|
||||
map<string, vector<int>> getStates(string name);
|
||||
pair<vector<vector<float>>&, vector<int>&> getVectors(string name);
|
||||
pair<vector<vector<int>>&, vector<int>&> getVectorsDiscretized(string name);
|
||||
pair<torch::Tensor&, torch::Tensor&> getTensors(string name);
|
||||
};
|
||||
};
|
||||
|
||||
#endif
|
Reference in New Issue
Block a user