Add list datasets and add locale format

This commit is contained in:
2023-08-19 19:05:16 +02:00
parent bafcb26bb6
commit 9972738deb
9 changed files with 190 additions and 46 deletions

View File

@@ -24,75 +24,110 @@ namespace platform {
transform(datasets.begin(), datasets.end(), back_inserter(result), [](const auto& d) { return d.first; });
return result;
}
vector<string> Datasets::getFeatures(string name)
vector<string> Datasets::getFeatures(const string& name) const
{
if (datasets[name]->isLoaded()) {
return datasets[name]->getFeatures();
if (datasets.at(name)->isLoaded()) {
return datasets.at(name)->getFeatures();
} else {
throw invalid_argument("Dataset not loaded.");
}
}
map<string, vector<int>> Datasets::getStates(string name)
map<string, vector<int>> Datasets::getStates(const string& name) const
{
if (datasets[name]->isLoaded()) {
return datasets[name]->getStates();
if (datasets.at(name)->isLoaded()) {
return datasets.at(name)->getStates();
} else {
throw invalid_argument("Dataset not loaded.");
}
}
string Datasets::getClassName(string name)
void Datasets::loadDataset(const string& name) const
{
if (datasets[name]->isLoaded()) {
return datasets[name]->getClassName();
if (datasets.at(name)->isLoaded()) {
return;
} else {
datasets.at(name)->load();
}
}
string Datasets::getClassName(const string& name) const
{
if (datasets.at(name)->isLoaded()) {
return datasets.at(name)->getClassName();
} else {
throw invalid_argument("Dataset not loaded.");
}
}
int Datasets::getNSamples(string name)
int Datasets::getNSamples(const string& name) const
{
if (datasets[name]->isLoaded()) {
return datasets[name]->getNSamples();
if (datasets.at(name)->isLoaded()) {
return datasets.at(name)->getNSamples();
} else {
throw invalid_argument("Dataset not loaded.");
}
}
pair<vector<vector<float>>&, vector<int>&> Datasets::getVectors(string name)
int Datasets::getNClasses(const string& name)
{
if (datasets.at(name)->isLoaded()) {
auto className = datasets.at(name)->getClassName();
if (discretize) {
auto states = getStates(name);
return states.at(className).size();
}
auto [Xv, yv] = getVectors(name);
return *max_element(yv.begin(), yv.end()) + 1;
} else {
throw invalid_argument("Dataset not loaded.");
}
}
vector<int> Datasets::getClassesCounts(const string& name) const
{
if (datasets.at(name)->isLoaded()) {
auto [Xv, yv] = datasets.at(name)->getVectors();
vector<int> counts(*max_element(yv.begin(), yv.end()) + 1);
for (auto y : yv) {
counts[y]++;
}
return counts;
} else {
throw invalid_argument("Dataset not loaded.");
}
}
pair<vector<vector<float>>&, vector<int>&> Datasets::getVectors(const string& name)
{
if (!datasets[name]->isLoaded()) {
datasets[name]->load();
}
return datasets[name]->getVectors();
}
pair<vector<vector<int>>&, vector<int>&> Datasets::getVectorsDiscretized(string name)
pair<vector<vector<int>>&, vector<int>&> Datasets::getVectorsDiscretized(const string& name)
{
if (!datasets[name]->isLoaded()) {
datasets[name]->load();
}
return datasets[name]->getVectorsDiscretized();
}
pair<torch::Tensor&, torch::Tensor&> Datasets::getTensors(string name)
pair<torch::Tensor&, torch::Tensor&> Datasets::getTensors(const string& name)
{
if (!datasets[name]->isLoaded()) {
datasets[name]->load();
}
return datasets[name]->getTensors();
}
bool Datasets::isDataset(const string& name)
bool Datasets::isDataset(const string& name) const
{
return datasets.find(name) != datasets.end();
}
Dataset::Dataset(const Dataset& dataset) : path(dataset.path), name(dataset.name), className(dataset.className), n_samples(dataset.n_samples), n_features(dataset.n_features), features(dataset.features), states(dataset.states), loaded(dataset.loaded), discretize(dataset.discretize), X(dataset.X), y(dataset.y), Xv(dataset.Xv), Xd(dataset.Xd), yv(dataset.yv), fileType(dataset.fileType)
{
}
string Dataset::getName()
string Dataset::getName() const
{
return name;
}
string Dataset::getClassName()
string Dataset::getClassName() const
{
return className;
}
vector<string> Dataset::getFeatures()
vector<string> Dataset::getFeatures() const
{
if (loaded) {
return features;
@@ -100,7 +135,7 @@ namespace platform {
throw invalid_argument("Dataset not loaded.");
}
}
int Dataset::getNFeatures()
int Dataset::getNFeatures() const
{
if (loaded) {
return n_features;
@@ -108,7 +143,7 @@ namespace platform {
throw invalid_argument("Dataset not loaded.");
}
}
int Dataset::getNSamples()
int Dataset::getNSamples() const
{
if (loaded) {
return n_samples;
@@ -116,7 +151,7 @@ namespace platform {
throw invalid_argument("Dataset not loaded.");
}
}
map<string, vector<int>> Dataset::getStates()
map<string, vector<int>> Dataset::getStates() const
{
if (loaded) {
return states;