Refactor crossvalidation to remove unneeded params
This commit is contained in:
parent
adc0ca238f
commit
c4f3e6f19a
@ -1,4 +1,5 @@
|
|||||||
#include "Experiment.h"
|
#include "Experiment.h"
|
||||||
|
#include "Datasets.h"
|
||||||
|
|
||||||
namespace platform {
|
namespace platform {
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
@ -88,16 +89,25 @@ namespace platform {
|
|||||||
json data = build_json();
|
json data = build_json();
|
||||||
cout << data.dump(4) << endl;
|
cout << data.dump(4) << endl;
|
||||||
}
|
}
|
||||||
Result Experiment::cross_validation(string model_name, torch::Tensor& Xt, torch::Tensor& y, vector<string> features, string className, map<string, vector<int>> states)
|
Result Experiment::cross_validation(const string& path, const string& fileName)
|
||||||
{
|
{
|
||||||
auto classifiers = map<string, bayesnet::BaseClassifier*>({
|
auto classifiers = map<string, bayesnet::BaseClassifier*>({
|
||||||
{ "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
|
{ "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) },
|
||||||
{ "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
|
{ "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() }
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
auto datasets = platform::Datasets(path, true, platform::ARFF);
|
||||||
|
// Get dataset
|
||||||
|
auto [X, y] = datasets.getTensors(fileName);
|
||||||
|
auto states = datasets.getStates(fileName);
|
||||||
|
auto features = datasets.getFeatures(fileName);
|
||||||
|
auto samples = datasets.getNSamples(fileName);
|
||||||
|
auto className = datasets.getClassName(fileName);
|
||||||
|
cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush;
|
||||||
|
// Prepare Result
|
||||||
auto result = Result();
|
auto result = Result();
|
||||||
auto [values, counts] = at::_unique(y);
|
auto [values, counts] = at::_unique(y);;
|
||||||
result.setSamples(Xt.size(1)).setFeatures(Xt.size(0)).setClasses(values.size(0));
|
result.setSamples(X.size(1)).setFeatures(X.size(0)).setClasses(values.size(0));
|
||||||
int nResults = nfolds * static_cast<int>(randomSeeds.size());
|
int nResults = nfolds * static_cast<int>(randomSeeds.size());
|
||||||
auto accuracy_test = torch::zeros({ nResults }, torch::kFloat64);
|
auto accuracy_test = torch::zeros({ nResults }, torch::kFloat64);
|
||||||
auto accuracy_train = torch::zeros({ nResults }, torch::kFloat64);
|
auto accuracy_train = torch::zeros({ nResults }, torch::kFloat64);
|
||||||
@ -123,9 +133,9 @@ namespace platform {
|
|||||||
auto [train, test] = fold->getFold(nfold);
|
auto [train, test] = fold->getFold(nfold);
|
||||||
auto train_t = torch::tensor(train);
|
auto train_t = torch::tensor(train);
|
||||||
auto test_t = torch::tensor(test);
|
auto test_t = torch::tensor(test);
|
||||||
auto X_train = Xt.index({ "...", train_t });
|
auto X_train = X.index({ "...", train_t });
|
||||||
auto y_train = y.index({ train_t });
|
auto y_train = y.index({ train_t });
|
||||||
auto X_test = Xt.index({ "...", test_t });
|
auto X_test = X.index({ "...", test_t });
|
||||||
auto y_test = y.index({ test_t });
|
auto y_test = y.index({ test_t });
|
||||||
cout << nfold + 1 << ", " << flush;
|
cout << nfold + 1 << ", " << flush;
|
||||||
clf->fit(X_train, y_train, features, className, states);
|
clf->fit(X_train, y_train, features, className, states);
|
||||||
|
@ -105,8 +105,7 @@ namespace platform {
|
|||||||
Experiment& setDuration(float duration) { this->duration = duration; return *this; }
|
Experiment& setDuration(float duration) { this->duration = duration; return *this; }
|
||||||
string get_file_name();
|
string get_file_name();
|
||||||
void save(string path);
|
void save(string path);
|
||||||
//Result cross_validation(const string& path, const string& fileName);
|
Result cross_validation(const string& path, const string& fileName);
|
||||||
Result cross_validation(string model_name, torch::Tensor& X, torch::Tensor& y, vector<string> features, string className, map<string, vector<int>> states);
|
|
||||||
void show();
|
void show();
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -120,13 +120,7 @@ int main(int argc, char** argv)
|
|||||||
timer.start();
|
timer.start();
|
||||||
for (auto fileName : filesToProcess) {
|
for (auto fileName : filesToProcess) {
|
||||||
cout << "- " << setw(20) << left << fileName << " " << right << flush;
|
cout << "- " << setw(20) << left << fileName << " " << right << flush;
|
||||||
auto [X, y] = datasets.getTensors(fileName);
|
auto result = experiment.cross_validation(path, fileName);
|
||||||
auto states = datasets.getStates(fileName);
|
|
||||||
auto features = datasets.getFeatures(fileName);
|
|
||||||
auto samples = datasets.getNSamples(fileName);
|
|
||||||
auto className = datasets.getClassName(fileName);
|
|
||||||
cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush;
|
|
||||||
auto result = experiment.cross_validation(model_name, X, y, features, className, states);
|
|
||||||
result.setDataset(fileName);
|
result.setDataset(fileName);
|
||||||
experiment.addResult(result);
|
experiment.addResult(result);
|
||||||
cout << endl;
|
cout << endl;
|
||||||
|
Loading…
Reference in New Issue
Block a user