diff --git a/src/Platform/Datasets.cc b/src/Platform/Datasets.cc
index acf78e5..17b2ee1 100644
--- a/src/Platform/Datasets.cc
+++ b/src/Platform/Datasets.cc
@@ -5,13 +5,25 @@ namespace platform {
     void Datasets::load()
     {
+        auto sd = SourceData(sfileType);
+        fileType = sd.getFileType();
+        path = sd.getPath();
         ifstream catalog(path + "all.txt");
         if (catalog.is_open()) {
             string line;
             while (getline(catalog, line)) {
+                if (line.empty() || line[0] == '#') {
+                    continue;
+                }
                 vector<string> tokens = split(line, ',');
                 string name = tokens[0];
-                string className = tokens[1];
+                string className;
+                try {
+                    className = tokens.at(1);
+                }
+                catch (const exception& e) {
+                    className = "-1";
+                }
                 datasets[name] = make_unique<Dataset>(path, name, className, discretize, fileType);
             }
             catalog.close();
@@ -193,7 +205,9 @@ namespace platform {
         getline(file, line);
         vector<string> tokens = split(line, ',');
         features = vector<string>(tokens.begin(), tokens.end() - 1);
-        className = tokens.back();
+        if (className == "-1") {
+            className = tokens.back();
+        }
         for (auto i = 0; i < features.size(); ++i) {
             Xv.push_back(vector<float>());
         }
@@ -231,6 +245,53 @@ namespace platform {
         auto attributes = arff.getAttributes();
         transform(attributes.begin(), attributes.end(), back_inserter(features), [](const auto& attribute) { return attribute.first; });
     }
+    vector<string> tokenize(string line)
+    {
+        vector<string> tokens;
+        for (auto i = 0; i < line.size(); ++i) {
+            if (line[i] == ' ' || line[i] == '\t' || line[i] == '\n') {
+                string token = line.substr(0, i);
+                tokens.push_back(token);
+                line.erase(line.begin(), line.begin() + i + 1);
+                i = 0;
+                while (line[i] == ' ' || line[i] == '\t' || line[i] == '\n')
+                    line.erase(line.begin(), line.begin() + i + 1);
+            }
+        }
+        if (line.size() > 0) {
+            tokens.push_back(line);
+        }
+        return tokens;
+    }
+    void Dataset::load_rdata()
+    {
+        ifstream file(path + "/" + name + "_R.dat");
+        if (file.is_open()) {
+            string line;
+            getline(file, line);
+            line = ArffFiles::trim(line);
+            vector<string> tokens = tokenize(line);
+            transform(tokens.begin(), tokens.end() - 1, back_inserter(features), [](const auto& attribute) { return ArffFiles::trim(attribute); });
+            if (className == "-1") {
+                className = ArffFiles::trim(tokens.back());
+            }
+            for (auto i = 0; i < features.size(); ++i) {
+                Xv.push_back(vector<float>());
+            }
+            while (getline(file, line)) {
+                tokens = tokenize(line);
+                // We have to skip the first token, which is the instance number.
+                for (auto i = 1; i < features.size() + 1; ++i) {
+                    const float value = stof(tokens[i]);
+                    Xv[i - 1].push_back(value);
+                }
+                yv.push_back(stoi(tokens.back()));
+            }
+            file.close();
+        } else {
+            throw invalid_argument("Unable to open dataset file.");
+        }
+    }
     void Dataset::load()
     {
         if (loaded) {
@@ -240,6 +301,8 @@
             load_csv();
         } else if (fileType == ARFF) {
             load_arff();
+        } else if (fileType == RDATA) {
+            load_rdata();
         }
         if (discretize) {
             Xd = discretizeDataset(Xv, yv);
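Note: the hand-rolled tokenize() above splits on spaces, tabs and newlines by repeatedly erasing the front of the string. For reference only, here is a minimal sketch of an equivalent formulation with std::istringstream; it is not part of the patch and the name tokenize_sketch is made up. Because operator>> already skips any run of whitespace, it also avoids pushing an empty leading token when a row starts with blanks, which the loop above would do on untrimmed input.

    // Sketch only, not part of the patch: an istringstream-based equivalent
    // of the whitespace tokenizer used by Dataset::load_rdata().
    #include <sstream>
    #include <string>
    #include <vector>

    std::vector<std::string> tokenize_sketch(const std::string& line)
    {
        std::vector<std::string> tokens;
        std::istringstream stream(line);
        std::string token;
        while (stream >> token) {   // operator>> skips runs of spaces, tabs and newlines
            tokens.push_back(token);
        }
        return tokens;
    }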
diff --git a/src/Platform/Datasets.h b/src/Platform/Datasets.h
index a99c86e..aa3c109 100644
--- a/src/Platform/Datasets.h
+++ b/src/Platform/Datasets.h
@@ -6,7 +6,36 @@
 #include
 namespace platform {
     using namespace std;
-    enum fileType_t { CSV, ARFF };
+    enum fileType_t { CSV, ARFF, RDATA };
+    class SourceData {
+    public:
+        SourceData(string source)
+        {
+            if (source == "Surcov") {
+                path = "datasets/";
+                fileType = CSV;
+            } else if (source == "Arff") {
+                path = "datasets/";
+                fileType = ARFF;
+            } else if (source == "Tanveer") {
+                path = "data/";
+                fileType = RDATA;
+            } else {
+                throw invalid_argument("Unknown source.");
+            }
+        }
+        string getPath()
+        {
+            return path;
+        }
+        fileType_t getFileType()
+        {
+            return fileType;
+        }
+    private:
+        string path;
+        fileType_t fileType;
+    };
     class Dataset {
     private:
         string path;
@@ -25,6 +54,7 @@ namespace platform {
         void buildTensors();
         void load_csv();
         void load_arff();
+        void load_rdata();
         void computeStates();
     public:
         Dataset(const string& path, const string& name, const string& className, bool discretize, fileType_t fileType) : path(path), name(name), className(className), discretize(discretize), loaded(false), fileType(fileType) {};
@@ -45,11 +75,12 @@ namespace platform {
     private:
         string path;
         fileType_t fileType;
+        string sfileType;
         map<string, unique_ptr<Dataset>> datasets;
         bool discretize;
         void load(); // Loads the list of datasets
     public:
-        explicit Datasets(const string& path, bool discretize = false, fileType_t fileType = ARFF) : path(path), discretize(discretize), fileType(fileType) { load(); };
+        explicit Datasets(bool discretize, string sfileType) : discretize(discretize), sfileType(sfileType) { load(); };
        vector<string> getNames();
        vector<string> getFeatures(const string& name) const;
        int getNSamples(const string& name) const;
diff --git a/src/Platform/Experiment.cc b/src/Platform/Experiment.cc
index c1d1048..f33e8d3 100644
--- a/src/Platform/Experiment.cc
+++ b/src/Platform/Experiment.cc
@@ -1,8 +1,9 @@
+#include
 #include "Experiment.h"
 #include "Datasets.h"
 #include "Models.h"
 #include "ReportConsole.h"
-#include
+#include "DotEnv.h"
 namespace platform {
     using json = nlohmann::json;
     string get_date()
@@ -133,7 +134,8 @@ namespace platform {
     }
     void Experiment::cross_validation(const string& path, const string& fileName)
     {
-        auto datasets = platform::Datasets(path, discretized, platform::ARFF);
+        auto env = platform::DotEnv();
+        auto datasets = platform::Datasets(discretized, env.get("source_data"));
         // Get dataset
         auto [X, y] = datasets.getTensors(fileName);
         auto states = datasets.getStates(fileName);
diff --git a/src/Platform/ReportBase.cc b/src/Platform/ReportBase.cc
index 3cac9d3..5f113a5 100644
--- a/src/Platform/ReportBase.cc
+++ b/src/Platform/ReportBase.cc
@@ -3,7 +3,7 @@
 #include "Datasets.h"
 #include "ReportBase.h"
 #include "BestScore.h"
-
+#include "DotEnv.h"
 namespace platform {
     ReportBase::ReportBase(json data_, bool compare) : data(data_), compare(compare), margin(0.1)
     {
@@ -58,7 +58,8 @@ namespace platform {
             }
         } else {
             if (data["score_name"].get<string>() == "accuracy") {
-                auto dt = Datasets(Paths::datasets(), false);
+                auto env = platform::DotEnv();
+                auto dt = Datasets(false, env.get("source_data"));
                 dt.loadDataset(dataset);
                 auto numClasses = dt.getNClasses(dataset);
                 if (numClasses == 2) {
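Note: with this change, callers build Datasets from the .env entry source_data instead of passing a path and a fileType_t; SourceData (constructed inside Datasets::load) maps the source name to a folder and a file type. A minimal sketch of the intended call pattern follows, assuming a .env line such as source_data=Tanveer; the key name comes from the env.get("source_data") calls in this patch, the accepted values "Surcov", "Arff" and "Tanveer" from SourceData's constructor, and the function name example_usage is made up.

    // Sketch only: how the new environment-driven constructor is meant to be used.
    #include "Datasets.h"
    #include "DotEnv.h"

    void example_usage()
    {
        auto env = platform::DotEnv();
        // SourceData resolves the source name to "datasets/" or "data/" and to
        // CSV, ARFF or RDATA, so callers no longer hard-code either of them.
        auto datasets = platform::Datasets(false, env.get("source_data"));
        for (const auto& name : datasets.getNames()) {
            datasets.loadDataset(name);
        }
    }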
diff --git a/src/Platform/list.cc b/src/Platform/list.cc
index ed8396d..8c386a5 100644
--- a/src/Platform/list.cc
+++ b/src/Platform/list.cc
@@ -3,6 +3,7 @@
 #include "Paths.h"
 #include "Colors.h"
 #include "Datasets.h"
+#include "DotEnv.h"
 using namespace std;
 const int BALANCE_LENGTH = 75;
@@ -27,7 +28,8 @@ void outputBalance(const string& balance)
 int main(int argc, char** argv)
 {
-    auto data = platform::Datasets(platform::Paths().datasets(), false);
+    auto env = platform::DotEnv();
+    auto data = platform::Datasets(false, env.get("source_data"));
     locale mylocale(cout.getloc(), new separated);
     locale::global(mylocale);
     cout.imbue(mylocale);
diff --git a/src/Platform/main.cc b/src/Platform/main.cc
index a122ad2..ccd4271 100644
--- a/src/Platform/main.cc
+++ b/src/Platform/main.cc
@@ -89,7 +89,8 @@ int main(int argc, char** argv)
     auto seeds = program.get<vector<int>>("seeds");
     auto hyperparameters = program.get<string>("hyperparameters");
     vector<string> filesToTest;
-    auto datasets = platform::Datasets(path, true, platform::ARFF);
+    auto env = platform::DotEnv();
+    auto datasets = platform::Datasets(discretize_dataset, env.get("source_data"));
     auto title = program.get<string>("title");
     auto saveResults = program.get<bool>("save");
     if (file_name != "") {
@@ -108,7 +109,7 @@ int main(int argc, char** argv)
     /*
     * Begin Processing
     */
-    auto env = platform::DotEnv();
+
     auto experiment = platform::Experiment();
     experiment.setTitle(title).setLanguage("cpp").setLanguageVersion("14.0.3");
     experiment.setDiscretized(discretize_dataset).setModel(model_name).setPlatform(env.get("platform"));
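Note: main.cc now forwards the command-line discretize_dataset flag instead of a hard-coded true, and env is created before the dataset list so it is still available for setPlatform(env.get("platform")). An unrecognized source_data value surfaces as the invalid_argument("Unknown source.") thrown by SourceData; the sketch below shows one way a caller could report that cleanly (illustrative only, the function name run_with_datasets is made up).

    // Sketch only: reporting a bad source_data value instead of terminating
    // on an uncaught exception.
    #include <iostream>
    #include <stdexcept>
    #include "Datasets.h"
    #include "DotEnv.h"

    int run_with_datasets()
    {
        auto env = platform::DotEnv();
        try {
            auto datasets = platform::Datasets(false, env.get("source_data"));
            // ... work with datasets ...
            return 0;
        }
        catch (const std::invalid_argument& e) {
            // Thrown by SourceData for values other than "Surcov", "Arff" or "Tanveer".
            std::cerr << "Invalid source_data in .env: " << e.what() << std::endl;
            return 1;
        }
    }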