Complete Experiment

This commit is contained in:
2023-07-27 15:49:58 +02:00
parent bc214a496c
commit 3d8fea7a37
6 changed files with 80 additions and 67 deletions

View File

@@ -28,13 +28,6 @@ namespace platform {
throw invalid_argument("Unable to open catalog file. [" + path + "/all.txt" + "]");
}
}
Dataset& Datasets::getDataset(string name)
{
if (datasets.find(name) == datasets.end()) {
throw invalid_argument("Dataset not found.");
}
return *datasets[name];
}
vector<string> Datasets::getNames()
{
vector<string> result;
@@ -45,45 +38,56 @@ namespace platform {
}
vector<string> Datasets::getFeatures(string name)
{
auto dataset = getDataset(name);
if (dataset.isLoaded()) {
return dataset.getFeatures();
if (datasets[name]->isLoaded()) {
return datasets[name]->getFeatures();
} else {
throw invalid_argument("Dataset not loaded.");
}
}
map<string, vector<int>> Datasets::getStates(string name)
{
auto dataset = getDataset(name);
if (dataset.isLoaded()) {
return dataset.getStates();
if (datasets[name]->isLoaded()) {
return datasets[name]->getStates();
} else {
throw invalid_argument("Dataset not loaded.");
}
}
string Datasets::getClassName(string name)
{
if (datasets[name]->isLoaded()) {
return datasets[name]->getClassName();
} else {
throw invalid_argument("Dataset not loaded.");
}
}
int Datasets::getNSamples(string name)
{
if (datasets[name]->isLoaded()) {
return datasets[name]->getNSamples();
} else {
throw invalid_argument("Dataset not loaded.");
}
}
pair<vector<vector<float>>&, vector<int>&> Datasets::getVectors(string name)
{
auto dataset = getDataset(name);
if (!dataset.isLoaded()) {
dataset.load();
if (!datasets[name]->isLoaded()) {
datasets[name]->load();
}
return dataset.getVectors();
return datasets[name]->getVectors();
}
pair<vector<vector<int>>&, vector<int>&> Datasets::getVectorsDiscretized(string name)
{
auto dataset = getDataset(name);
if (!dataset.isLoaded()) {
dataset.load();
if (!datasets[name]->isLoaded()) {
datasets[name]->load();
}
return dataset.getVectorsDiscretized();
return datasets[name]->getVectorsDiscretized();
}
pair<torch::Tensor&, torch::Tensor&> Datasets::getTensors(string name)
{
auto dataset = getDataset(name);
if (!dataset.isLoaded()) {
dataset.load();
if (!datasets[name]->isLoaded()) {
datasets[name]->load();
}
return dataset.getTensors();
return datasets[name]->getTensors();
}
Dataset::Dataset(Dataset& dataset)
{
@@ -195,11 +199,11 @@ namespace platform {
void Dataset::computeStates()
{
for (int i = 0; i < features.size(); ++i) {
states[features[i]] = vector<int>(*max_element(Xd[i].begin(), Xd[i].end()));
iota(Xd[i].begin(), Xd[i].end(), 0);
states[features[i]] = vector<int>(*max_element(Xd[i].begin(), Xd[i].end()) + 1);
iota(begin(states[features[i]]), end(states[features[i]]), 0);
}
states[className] = vector<int>(*max_element(yv.begin(), yv.end()));
iota(yv.begin(), yv.end(), 0);
states[className] = vector<int>(*max_element(yv.begin(), yv.end()) + 1);
iota(begin(states[className]), end(states[className]), 0);
}
void Dataset::load_arff()
{
@@ -209,8 +213,7 @@ namespace platform {
Xv = arff.getX();
yv = arff.getY();
// Get className & Features
auto className = arff.getClassName();
vector<string> features;
className = arff.getClassName();
for (auto feature : arff.getAttributes()) {
features.push_back(feature.first);
}
@@ -246,7 +249,7 @@ namespace platform {
} else {
X.index_put_({ i, "..." }, torch::tensor(Xv[i], torch::kFloat32));
}
y = torch::tensor(yv, torch::kInt32);
}
y = torch::tensor(yv, torch::kInt32);
}
}