From c4f3e6f19a5fd2612954fcd09e3e4ed8645f0dad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Sat, 29 Jul 2023 16:49:06 +0200 Subject: [PATCH] Refactor crossvalidation to remove unneeded params --- src/Platform/Experiment.cc | 20 +++++++++++++++----- src/Platform/Experiment.h | 3 +-- src/Platform/main.cc | 8 +------- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/src/Platform/Experiment.cc b/src/Platform/Experiment.cc index 22a3c84..8592019 100644 --- a/src/Platform/Experiment.cc +++ b/src/Platform/Experiment.cc @@ -1,4 +1,5 @@ #include "Experiment.h" +#include "Datasets.h" namespace platform { using json = nlohmann::json; @@ -88,16 +89,25 @@ namespace platform { json data = build_json(); cout << data.dump(4) << endl; } - Result Experiment::cross_validation(string model_name, torch::Tensor& Xt, torch::Tensor& y, vector features, string className, map> states) + Result Experiment::cross_validation(const string& path, const string& fileName) { auto classifiers = map({ { "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) }, { "SPODE", new bayesnet::SPODE(2) }, { "TAN", new bayesnet::TAN() } } ); + auto datasets = platform::Datasets(path, true, platform::ARFF); + // Get dataset + auto [X, y] = datasets.getTensors(fileName); + auto states = datasets.getStates(fileName); + auto features = datasets.getFeatures(fileName); + auto samples = datasets.getNSamples(fileName); + auto className = datasets.getClassName(fileName); + cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush; + // Prepare Result auto result = Result(); - auto [values, counts] = at::_unique(y); - result.setSamples(Xt.size(1)).setFeatures(Xt.size(0)).setClasses(values.size(0)); + auto [values, counts] = at::_unique(y);; + result.setSamples(X.size(1)).setFeatures(X.size(0)).setClasses(values.size(0)); int nResults = nfolds * static_cast(randomSeeds.size()); auto accuracy_test = torch::zeros({ nResults }, torch::kFloat64); auto accuracy_train = torch::zeros({ nResults }, torch::kFloat64); @@ -123,9 +133,9 @@ namespace platform { auto [train, test] = fold->getFold(nfold); auto train_t = torch::tensor(train); auto test_t = torch::tensor(test); - auto X_train = Xt.index({ "...", train_t }); + auto X_train = X.index({ "...", train_t }); auto y_train = y.index({ train_t }); - auto X_test = Xt.index({ "...", test_t }); + auto X_test = X.index({ "...", test_t }); auto y_test = y.index({ test_t }); cout << nfold + 1 << ", " << flush; clf->fit(X_train, y_train, features, className, states); diff --git a/src/Platform/Experiment.h b/src/Platform/Experiment.h index a45113d..84b1627 100644 --- a/src/Platform/Experiment.h +++ b/src/Platform/Experiment.h @@ -105,8 +105,7 @@ namespace platform { Experiment& setDuration(float duration) { this->duration = duration; return *this; } string get_file_name(); void save(string path); - //Result cross_validation(const string& path, const string& fileName); - Result cross_validation(string model_name, torch::Tensor& X, torch::Tensor& y, vector features, string className, map> states); + Result cross_validation(const string& path, const string& fileName); void show(); }; } diff --git a/src/Platform/main.cc b/src/Platform/main.cc index 1c6897d..d7d040d 100644 --- a/src/Platform/main.cc +++ b/src/Platform/main.cc @@ -120,13 +120,7 @@ int main(int argc, char** argv) timer.start(); for (auto fileName : filesToProcess) { cout << "- " << setw(20) << left << fileName << " " << right << flush; - auto [X, y] = datasets.getTensors(fileName); - auto states = datasets.getStates(fileName); - auto features = datasets.getFeatures(fileName); - auto samples = datasets.getNSamples(fileName); - auto className = datasets.getClassName(fileName); - cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush; - auto result = experiment.cross_validation(model_name, X, y, features, className, states); + auto result = experiment.cross_validation(path, fileName); result.setDataset(fileName); experiment.addResult(result); cout << endl;