Add csv and R_dat files to platform

This commit is contained in:
Ricardo Montañana Gómez 2023-09-29 13:52:50 +02:00
parent db17c14042
commit bb423da42f
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
6 changed files with 111 additions and 11 deletions

View File

@ -5,13 +5,25 @@
namespace platform {
void Datasets::load()
{
auto sd = SourceData(sfileType);
fileType = sd.getFileType();
path = sd.getPath();
ifstream catalog(path + "all.txt");
if (catalog.is_open()) {
string line;
while (getline(catalog, line)) {
if (line.empty() || line[0] == '#') {
continue;
}
vector<string> tokens = split(line, ',');
string name = tokens[0];
string className = tokens[1];
string className;
try {
className = tokens[1];
}
catch (exception e) {
className = "-1";
}
datasets[name] = make_unique<Dataset>(path, name, className, discretize, fileType);
}
catalog.close();
@ -193,7 +205,9 @@ namespace platform {
getline(file, line);
vector<string> tokens = split(line, ',');
features = vector<string>(tokens.begin(), tokens.end() - 1);
className = tokens.back();
if (className == "-1") {
className = tokens.back();
}
for (auto i = 0; i < features.size(); ++i) {
Xv.push_back(vector<float>());
}
@ -231,6 +245,53 @@ namespace platform {
auto attributes = arff.getAttributes();
transform(attributes.begin(), attributes.end(), back_inserter(features), [](const auto& attribute) { return attribute.first; });
}
vector<string> tokenize(string line)
{
vector<string> tokens;
for (auto i = 0; i < line.size(); ++i) {
if (line[i] == ' ' || line[i] == '\t' || line[i] == '\n') {
string token = line.substr(0, i);
tokens.push_back(token);
line.erase(line.begin(), line.begin() + i + 1);
i = 0;
while (line[i] == ' ' || line[i] == '\t' || line[i] == '\n')
line.erase(line.begin(), line.begin() + i + 1);
}
}
if (line.size() > 0) {
tokens.push_back(line);
}
return tokens;
}
void Dataset::load_rdata()
{
ifstream file(path + "/" + name + "_R.dat");
if (file.is_open()) {
string line;
getline(file, line);
line = ArffFiles::trim(line);
vector<string> tokens = tokenize(line);
transform(tokens.begin(), tokens.end() - 1, back_inserter(features), [](const auto& attribute) { return ArffFiles::trim(attribute); });
if (className == "-1") {
className = ArffFiles::trim(tokens.back());
}
for (auto i = 0; i < features.size(); ++i) {
Xv.push_back(vector<float>());
}
while (getline(file, line)) {
tokens = tokenize(line);
// We have to skip the first token, which is the instance number.
for (auto i = 1; i < features.size() + 1; ++i) {
const float value = stof(tokens[i]);
Xv[i - 1].push_back(value);
}
yv.push_back(stoi(tokens.back()));
}
file.close();
} else {
throw invalid_argument("Unable to open dataset file.");
}
}
void Dataset::load()
{
if (loaded) {
@ -240,6 +301,8 @@ namespace platform {
load_csv();
} else if (fileType == ARFF) {
load_arff();
} else if (fileType == RDATA) {
load_rdata();
}
if (discretize) {
Xd = discretizeDataset(Xv, yv);

View File

@ -6,7 +6,36 @@
#include <string>
namespace platform {
using namespace std;
enum fileType_t { CSV, ARFF };
enum fileType_t { CSV, ARFF, RDATA };
class SourceData {
public:
SourceData(string source)
{
if (source == "Surcov") {
path = "datasets/";
fileType = CSV;
} else if (source == "Arff") {
path = "datasets/";
fileType = ARFF;
} else if (source == "Tanveer") {
path = "data/";
fileType = RDATA;
} else {
throw invalid_argument("Unknown source.");
}
}
string getPath()
{
return path;
}
fileType_t getFileType()
{
return fileType;
}
private:
string path;
fileType_t fileType;
};
class Dataset {
private:
string path;
@ -25,6 +54,7 @@ namespace platform {
void buildTensors();
void load_csv();
void load_arff();
void load_rdata();
void computeStates();
public:
Dataset(const string& path, const string& name, const string& className, bool discretize, fileType_t fileType) : path(path), name(name), className(className), discretize(discretize), loaded(false), fileType(fileType) {};
@ -45,11 +75,12 @@ namespace platform {
private:
string path;
fileType_t fileType;
string sfileType;
map<string, unique_ptr<Dataset>> datasets;
bool discretize;
void load(); // Loads the list of datasets
public:
explicit Datasets(const string& path, bool discretize = false, fileType_t fileType = ARFF) : path(path), discretize(discretize), fileType(fileType) { load(); };
explicit Datasets(bool discretize, string sfileType) : discretize(discretize), sfileType(sfileType) { load(); };
vector<string> getNames();
vector<string> getFeatures(const string& name) const;
int getNSamples(const string& name) const;

View File

@ -1,8 +1,9 @@
#include <fstream>
#include "Experiment.h"
#include "Datasets.h"
#include "Models.h"
#include "ReportConsole.h"
#include <fstream>
#include "DotEnv.h"
namespace platform {
using json = nlohmann::json;
string get_date()
@ -133,7 +134,8 @@ namespace platform {
}
void Experiment::cross_validation(const string& path, const string& fileName)
{
auto datasets = platform::Datasets(path, discretized, platform::ARFF);
auto env = platform::DotEnv();
auto datasets = platform::Datasets(discretized, env.get("source_data"));
// Get dataset
auto [X, y] = datasets.getTensors(fileName);
auto states = datasets.getStates(fileName);

View File

@ -3,7 +3,7 @@
#include "Datasets.h"
#include "ReportBase.h"
#include "BestScore.h"
#include "DotEnv.h"
namespace platform {
ReportBase::ReportBase(json data_, bool compare) : data(data_), compare(compare), margin(0.1)
@ -58,7 +58,8 @@ namespace platform {
}
} else {
if (data["score_name"].get<string>() == "accuracy") {
auto dt = Datasets(Paths::datasets(), false);
auto env = platform::DotEnv();
auto dt = Datasets(false, env.get("source_data"));
dt.loadDataset(dataset);
auto numClasses = dt.getNClasses(dataset);
if (numClasses == 2) {

View File

@ -3,6 +3,7 @@
#include "Paths.h"
#include "Colors.h"
#include "Datasets.h"
#include "DotEnv.h"
using namespace std;
const int BALANCE_LENGTH = 75;
@ -27,7 +28,8 @@ void outputBalance(const string& balance)
int main(int argc, char** argv)
{
auto data = platform::Datasets(platform::Paths().datasets(), false);
auto env = platform::DotEnv();
auto data = platform::Datasets(false, env.get("source_data"));
locale mylocale(cout.getloc(), new separated);
locale::global(mylocale);
cout.imbue(mylocale);

View File

@ -89,7 +89,8 @@ int main(int argc, char** argv)
auto seeds = program.get<vector<int>>("seeds");
auto hyperparameters = program.get<string>("hyperparameters");
vector<string> filesToTest;
auto datasets = platform::Datasets(path, true, platform::ARFF);
auto env = platform::DotEnv();
auto datasets = platform::Datasets(discretize_dataset, env.get("source_data"));
auto title = program.get<string>("title");
auto saveResults = program.get<bool>("save");
if (file_name != "") {
@ -108,7 +109,7 @@ int main(int argc, char** argv)
/*
* Begin Processing
*/
auto env = platform::DotEnv();
auto experiment = platform::Experiment();
experiment.setTitle(title).setLanguage("cpp").setLanguageVersion("14.0.3");
experiment.setDiscretized(discretize_dataset).setModel(model_name).setPlatform(env.get("platform"));