Complete Experiment
This commit is contained in:
@@ -28,13 +28,6 @@ namespace platform {
|
||||
throw invalid_argument("Unable to open catalog file. [" + path + "/all.txt" + "]");
|
||||
}
|
||||
}
|
||||
Dataset& Datasets::getDataset(string name)
|
||||
{
|
||||
if (datasets.find(name) == datasets.end()) {
|
||||
throw invalid_argument("Dataset not found.");
|
||||
}
|
||||
return *datasets[name];
|
||||
}
|
||||
vector<string> Datasets::getNames()
|
||||
{
|
||||
vector<string> result;
|
||||
@@ -45,45 +38,56 @@ namespace platform {
|
||||
}
|
||||
vector<string> Datasets::getFeatures(string name)
|
||||
{
|
||||
auto dataset = getDataset(name);
|
||||
if (dataset.isLoaded()) {
|
||||
return dataset.getFeatures();
|
||||
if (datasets[name]->isLoaded()) {
|
||||
return datasets[name]->getFeatures();
|
||||
} else {
|
||||
throw invalid_argument("Dataset not loaded.");
|
||||
}
|
||||
}
|
||||
map<string, vector<int>> Datasets::getStates(string name)
|
||||
{
|
||||
auto dataset = getDataset(name);
|
||||
if (dataset.isLoaded()) {
|
||||
return dataset.getStates();
|
||||
if (datasets[name]->isLoaded()) {
|
||||
return datasets[name]->getStates();
|
||||
} else {
|
||||
throw invalid_argument("Dataset not loaded.");
|
||||
}
|
||||
}
|
||||
string Datasets::getClassName(string name)
|
||||
{
|
||||
if (datasets[name]->isLoaded()) {
|
||||
return datasets[name]->getClassName();
|
||||
} else {
|
||||
throw invalid_argument("Dataset not loaded.");
|
||||
}
|
||||
}
|
||||
int Datasets::getNSamples(string name)
|
||||
{
|
||||
if (datasets[name]->isLoaded()) {
|
||||
return datasets[name]->getNSamples();
|
||||
} else {
|
||||
throw invalid_argument("Dataset not loaded.");
|
||||
}
|
||||
}
|
||||
pair<vector<vector<float>>&, vector<int>&> Datasets::getVectors(string name)
|
||||
{
|
||||
auto dataset = getDataset(name);
|
||||
if (!dataset.isLoaded()) {
|
||||
dataset.load();
|
||||
if (!datasets[name]->isLoaded()) {
|
||||
datasets[name]->load();
|
||||
}
|
||||
return dataset.getVectors();
|
||||
return datasets[name]->getVectors();
|
||||
}
|
||||
pair<vector<vector<int>>&, vector<int>&> Datasets::getVectorsDiscretized(string name)
|
||||
{
|
||||
auto dataset = getDataset(name);
|
||||
if (!dataset.isLoaded()) {
|
||||
dataset.load();
|
||||
if (!datasets[name]->isLoaded()) {
|
||||
datasets[name]->load();
|
||||
}
|
||||
return dataset.getVectorsDiscretized();
|
||||
return datasets[name]->getVectorsDiscretized();
|
||||
}
|
||||
pair<torch::Tensor&, torch::Tensor&> Datasets::getTensors(string name)
|
||||
{
|
||||
auto dataset = getDataset(name);
|
||||
if (!dataset.isLoaded()) {
|
||||
dataset.load();
|
||||
if (!datasets[name]->isLoaded()) {
|
||||
datasets[name]->load();
|
||||
}
|
||||
return dataset.getTensors();
|
||||
return datasets[name]->getTensors();
|
||||
}
|
||||
Dataset::Dataset(Dataset& dataset)
|
||||
{
|
||||
@@ -195,11 +199,11 @@ namespace platform {
|
||||
void Dataset::computeStates()
|
||||
{
|
||||
for (int i = 0; i < features.size(); ++i) {
|
||||
states[features[i]] = vector<int>(*max_element(Xd[i].begin(), Xd[i].end()));
|
||||
iota(Xd[i].begin(), Xd[i].end(), 0);
|
||||
states[features[i]] = vector<int>(*max_element(Xd[i].begin(), Xd[i].end()) + 1);
|
||||
iota(begin(states[features[i]]), end(states[features[i]]), 0);
|
||||
}
|
||||
states[className] = vector<int>(*max_element(yv.begin(), yv.end()));
|
||||
iota(yv.begin(), yv.end(), 0);
|
||||
states[className] = vector<int>(*max_element(yv.begin(), yv.end()) + 1);
|
||||
iota(begin(states[className]), end(states[className]), 0);
|
||||
}
|
||||
void Dataset::load_arff()
|
||||
{
|
||||
@@ -209,8 +213,7 @@ namespace platform {
|
||||
Xv = arff.getX();
|
||||
yv = arff.getY();
|
||||
// Get className & Features
|
||||
auto className = arff.getClassName();
|
||||
vector<string> features;
|
||||
className = arff.getClassName();
|
||||
for (auto feature : arff.getAttributes()) {
|
||||
features.push_back(feature.first);
|
||||
}
|
||||
@@ -246,7 +249,7 @@ namespace platform {
|
||||
} else {
|
||||
X.index_put_({ i, "..." }, torch::tensor(Xv[i], torch::kFloat32));
|
||||
}
|
||||
y = torch::tensor(yv, torch::kInt32);
|
||||
}
|
||||
y = torch::tensor(yv, torch::kInt32);
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user