#include "Datasets.h"
// NOTE(review): the reviewed copy of this file had lost every <...> span
// (header names and template arguments). The header names below and the
// template arguments used throughout this file were reconstructed from usage;
// confirm them against Datasets.h before merging.
#include <algorithm>
#include <fstream>
#include <iterator>
#include <stdexcept>

namespace platform {
    namespace {
        // Returns the dataset registered under `name`, loading it first if needed.
        // Throws out_of_range (from at()) when `name` is not in the registry.
        template <typename Registry>
        auto& ensureLoaded(const Registry& registry, const string& name)
        {
            auto& dataset = *registry.at(name);
            if (!dataset.isLoaded()) {
                dataset.load();
            }
            return dataset;
        }

        // Returns the dataset registered under `name`, which must already be loaded.
        // Throws out_of_range (from at()) when `name` is not in the registry and
        // invalid_argument when the dataset exists but has not been loaded.
        template <typename Registry>
        auto& requireLoaded(const Registry& registry, const string& name)
        {
            auto& dataset = *registry.at(name);
            if (!dataset.isLoaded()) {
                throw invalid_argument("Dataset not loaded.");
            }
            return dataset;
        }
    }

    // Reads the catalog file "<path>all.txt" and registers one Dataset per entry.
    // Empty lines and lines starting with '#' are skipped. Each remaining line is
    // split on ','; the first token is the dataset name and the second, when
    // present, the class name ("-1" is passed to Dataset when it is absent).
    // Throws invalid_argument when the catalog cannot be opened.
    void Datasets::load()
    {
        auto sd = SourceData(sfileType);
        fileType = sd.getFileType();
        path = sd.getPath();
        const string catalogPath = path + "all.txt";
        ifstream catalog(catalogPath);
        if (!catalog.is_open()) {
            throw invalid_argument("Unable to open catalog file. [" + catalogPath + "]");
        }
        string line;
        while (getline(catalog, line)) {
            if (line.empty() || line[0] == '#') {
                continue;
            }
            const auto tokens = split(line, ',');
            if (tokens.empty()) {
                // Nothing usable on this line; indexing tokens[0] would be out of bounds.
                continue;
            }
            const string name = tokens[0];
            const string className = tokens.size() == 1 ? string("-1") : string(tokens[1]);
            datasets[name] = make_unique<Dataset>(path, name, className, discretize, fileType);
        }
        // The stream is closed by its destructor when it goes out of scope.
    }

    // Returns the names of all registered datasets, in the container's iteration order.
    vector<string> Datasets::getNames()
    {
        vector<string> result;
        result.reserve(datasets.size());
        transform(datasets.begin(), datasets.end(), back_inserter(result), [](const auto& d) { return d.first; });
        return result;
    }

    // Returns the features of a loaded dataset.
    // Throws out_of_range for an unknown name, invalid_argument if not loaded.
    vector<string> Datasets::getFeatures(const string& name) const
    {
        return requireLoaded(datasets, name).getFeatures();
    }

    // Returns the states of a loaded dataset.
    // Throws out_of_range for an unknown name, invalid_argument if not loaded.
    map<string, vector<int>> Datasets::getStates(const string& name) const
    {
        return requireLoaded(datasets, name).getStates();
    }

    // Loads the named dataset unless it is already loaded.
    // Throws out_of_range for an unknown name.
    void Datasets::loadDataset(const string& name) const
    {
        ensureLoaded(datasets, name);
    }

    // Returns the class name of a loaded dataset.
    // Throws out_of_range for an unknown name, invalid_argument if not loaded.
    string Datasets::getClassName(const string& name) const
    {
        return requireLoaded(datasets, name).getClassName();
    }

    // Returns the number of samples of a loaded dataset.
    // Throws out_of_range for an unknown name, invalid_argument if not loaded.
    int Datasets::getNSamples(const string& name) const
    {
        return requireLoaded(datasets, name).getNSamples();
    }

    // Returns the number of classes of a loaded dataset.
    // When `discretize` is set, this is the number of states recorded for the
    // class variable; otherwise it is the largest label value plus one
    // (0 when the dataset has no labels).
    // Throws out_of_range for an unknown name, invalid_argument if not loaded.
    int Datasets::getNClasses(const string& name)
    {
        auto& dataset = requireLoaded(datasets, name);
        if (discretize) {
            const auto className = dataset.getClassName();
            return static_cast<int>(dataset.getStates().at(className).size());
        }
        const auto& labels = dataset.getVectors().second;
        if (labels.empty()) {
            // max_element on an empty range returns end(), which must not be dereferenced.
            return 0;
        }
        return *max_element(labels.begin(), labels.end()) + 1;
    }

    // Returns, for each label value v in [0, max label], how many samples carry v.
    // Returns an empty vector when the dataset has no labels.
    // NOTE(review): assumes labels are non-negative class indices - confirm.
    // Throws out_of_range for an unknown name, invalid_argument if not loaded.
    vector<int> Datasets::getClassesCounts(const string& name) const
    {
        auto& dataset = requireLoaded(datasets, name);
        const auto& labels = dataset.getVectors().second;
        if (labels.empty()) {
            return {};
        }
        vector<int> counts(*max_element(labels.begin(), labels.end()) + 1);
        for (const auto label : labels) {
            ++counts[label];
        }
        return counts;
    }

    // Returns the dataset's vectors, loading the dataset first if necessary.
    // Uses at() rather than operator[] so that an unknown name throws
    // out_of_range instead of inserting a null entry and dereferencing it.
    pair<vector<vector<float>>&, vector<int>&> Datasets::getVectors(const string& name)
    {
        return ensureLoaded(datasets, name).getVectors();
    }

    // Returns the dataset's discretized vectors, loading the dataset first if necessary.
    // Throws out_of_range for an unknown name.
    pair<vector<vector<int>>&, vector<int>&> Datasets::getVectorsDiscretized(const string& name)
    {
        return ensureLoaded(datasets, name).getVectorsDiscretized();
    }

    // Returns the dataset's tensors, loading the dataset first if necessary.
    // Throws out_of_range for an unknown name.
    pair<torch::Tensor&, torch::Tensor&> Datasets::getTensors(const string& name)
    {
        return ensureLoaded(datasets, name).getTensors();
    }

    // Returns true when a dataset with the given name is registered.
    bool Datasets::isDataset(const string& name) const
    {
        return datasets.find(name) != datasets.end();
    }
}