From adc0ca238f4d5b46081158f36b195f0f685352c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Sat, 29 Jul 2023 16:44:07 +0200 Subject: [PATCH] Refactor cross_validation --- src/Platform/Experiment.cc | 84 +++++++++++++++++++++++--------------- src/Platform/Experiment.h | 4 +- src/Platform/main.cc | 16 ++------ 3 files changed, 55 insertions(+), 49 deletions(-) diff --git a/src/Platform/Experiment.cc b/src/Platform/Experiment.cc index b8c65b8..22a3c84 100644 --- a/src/Platform/Experiment.cc +++ b/src/Platform/Experiment.cc @@ -88,7 +88,7 @@ namespace platform { json data = build_json(); cout << data.dump(4) << endl; } - Result cross_validation(Fold* fold, string model_name, torch::Tensor& Xt, torch::Tensor& y, vector features, string className, map> states) + Result Experiment::cross_validation(string model_name, torch::Tensor& Xt, torch::Tensor& y, vector features, string className, map> states) { auto classifiers = map({ { "AODE", new bayesnet::AODE() }, { "KDB", new bayesnet::KDB(2) }, @@ -98,41 +98,57 @@ namespace platform { auto result = Result(); auto [values, counts] = at::_unique(y); result.setSamples(Xt.size(1)).setFeatures(Xt.size(0)).setClasses(values.size(0)); - auto k = fold->getNumberOfFolds(); - auto accuracy_test = torch::zeros({ k }, torch::kFloat64); - auto accuracy_train = torch::zeros({ k }, torch::kFloat64); - auto train_time = torch::zeros({ k }, torch::kFloat64); - auto test_time = torch::zeros({ k }, torch::kFloat64); - auto nodes = torch::zeros({ k }, torch::kFloat64); - auto edges = torch::zeros({ k }, torch::kFloat64); - auto num_states = torch::zeros({ k }, torch::kFloat64); + int nResults = nfolds * static_cast(randomSeeds.size()); + auto accuracy_test = torch::zeros({ nResults }, torch::kFloat64); + auto accuracy_train = torch::zeros({ nResults }, torch::kFloat64); + auto train_time = torch::zeros({ nResults }, torch::kFloat64); + auto test_time = torch::zeros({ nResults }, torch::kFloat64); + auto nodes = torch::zeros({ nResults }, torch::kFloat64); + auto edges = torch::zeros({ nResults }, torch::kFloat64); + auto num_states = torch::zeros({ nResults }, torch::kFloat64); Timer train_timer, test_timer; - cout << "doing Fold: " << flush; - for (int i = 0; i < k; i++) { - bayesnet::BaseClassifier* model = classifiers[model_name]; - train_timer.start(); - auto [train, test] = fold->getFold(i); - auto train_t = torch::tensor(train); - auto test_t = torch::tensor(test); - auto X_train = Xt.index({ "...", train_t }); - auto y_train = y.index({ train_t }); - auto X_test = Xt.index({ "...", test_t }); - auto y_test = y.index({ test_t }); - cout << i + 1 << ", " << flush; - model->fit(X_train, y_train, features, className, states); - nodes[i] = model->getNumberOfNodes(); - edges[i] = model->getNumberOfEdges(); - num_states[i] = model->getNumberOfStates(); - train_time[i] = train_timer.getDuration(); - auto accuracy_train_value = model->score(X_train, y_train); - test_timer.start(); - auto accuracy_test_value = model->score(X_test, y_test); - test_time[i] = test_timer.getDuration(); - accuracy_train[i] = accuracy_train_value; - accuracy_test[i] = accuracy_test_value; - + int item = 0; + for (auto seed : randomSeeds) { + cout << "(" << seed << ") " << flush; + Fold* fold; + if (stratified) + fold = new StratifiedKFold(nfolds, y, seed); + else + fold = new KFold(nfolds, y.size(0), seed); + cout << "doing Fold: " << flush; + for (int nfold = 0; nfold < nfolds; nfold++) { + bayesnet::BaseClassifier* clf = classifiers[model]; + setModelVersion(clf->getVersion()); + train_timer.start(); + auto [train, test] = fold->getFold(nfold); + auto train_t = torch::tensor(train); + auto test_t = torch::tensor(test); + auto X_train = Xt.index({ "...", train_t }); + auto y_train = y.index({ train_t }); + auto X_test = Xt.index({ "...", test_t }); + auto y_test = y.index({ test_t }); + cout << nfold + 1 << ", " << flush; + clf->fit(X_train, y_train, features, className, states); + nodes[item] = clf->getNumberOfNodes(); + edges[item] = clf->getNumberOfEdges(); + num_states[item] = clf->getNumberOfStates(); + train_time[item] = train_timer.getDuration(); + auto accuracy_train_value = clf->score(X_train, y_train); + test_timer.start(); + auto accuracy_test_value = clf->score(X_test, y_test); + test_time[item] = test_timer.getDuration(); + accuracy_train[item] = accuracy_train_value; + accuracy_test[item] = accuracy_test_value; + // Store results and times in vector + result.addScoreTrain(accuracy_train_value); + result.addScoreTest(accuracy_test_value); + result.addTimeTrain(train_time[item].item()); + result.addTimeTest(test_time[item].item()); + item++; + } + cout << "end. " << flush; + delete fold; } - cout << "end. " << flush; result.setScoreTest(torch::mean(accuracy_test).item()).setScoreTrain(torch::mean(accuracy_train).item()); result.setScoreTestStd(torch::std(accuracy_test).item()).setScoreTrainStd(torch::std(accuracy_train).item()); result.setTrainTime(torch::mean(train_time).item()).setTestTime(torch::mean(test_time).item()); diff --git a/src/Platform/Experiment.h b/src/Platform/Experiment.h index 58639d7..a45113d 100644 --- a/src/Platform/Experiment.h +++ b/src/Platform/Experiment.h @@ -105,9 +105,9 @@ namespace platform { Experiment& setDuration(float duration) { this->duration = duration; return *this; } string get_file_name(); void save(string path); - Result cross_validation(const string& path, const string& fileName); + //Result cross_validation(const string& path, const string& fileName); + Result cross_validation(string model_name, torch::Tensor& X, torch::Tensor& y, vector features, string className, map> states); void show(); }; - Result cross_validation(Fold* fold, string model_name, torch::Tensor& X, torch::Tensor& y, vector features, string className, map> states); } #endif \ No newline at end of file diff --git a/src/Platform/main.cc b/src/Platform/main.cc index b7be5b5..1c6897d 100644 --- a/src/Platform/main.cc +++ b/src/Platform/main.cc @@ -126,19 +126,9 @@ int main(int argc, char** argv) auto samples = datasets.getNSamples(fileName); auto className = datasets.getClassName(fileName); cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush; - for (auto seed : seeds) { - cout << "(" << seed << ") " << flush; - Fold* fold; - if (stratified) - fold = new StratifiedKFold(n_folds, y, seed); - else - fold = new KFold(n_folds, samples, seed); - auto result = platform::cross_validation(fold, model_name, X, y, features, className, states); - result.setDataset(fileName); - experiment.setModelVersion("-FIXME-"); - experiment.addResult(result); - delete fold; - } + auto result = experiment.cross_validation(model_name, X, y, features, className, states); + result.setDataset(fileName); + experiment.addResult(result); cout << endl; } experiment.setDuration(timer.getDuration());