From 41257ed5661b629a6f142ff72f03953ea423e4b1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?=
Date: Sun, 10 Sep 2023 19:50:36 +0200
Subject: [PATCH] If ! convergence don't predict test

---
// Reviewer notes. This is free text between the "---" marker and the
// diffstat; it is not part of the commit message.
//
// What the hunks below change:
//  * BoostAODE::validationInit(): the StratifiedKFold(5, y_, 271) split
//    (fold 0), the clone into dataset_, the replacement of `dataset` by
//    X_train, the update of `m` / `metrics` and the buildDataset(y_train)
//    call now run only when `convergence` is set. Otherwise X_train is every
//    row of `dataset` except the last one and y_train is that last row (y_).
//  * BoostAODE::trainModel(): `tolerance` becomes a fixed 5 (it was
//    `convergence ? 5 : INT_MAX`), and predict(X_test) together with the
//    accuracy / delta / count bookkeeping is wrapped in `if (convergence)`.
//    In the visible hunks `count` is only incremented inside that guard.
//  * Network::setStates(): the index loop over `features` is replaced by
//    for_each with a lambda capturing `this` and `states` by reference; the
//    per-feature setNumStates(...) call is the same.
//
// NOTE(review): the `else` branch of validationInit() does not assign `m`,
//   `metrics`, `X_test` or `y_test` - confirm nothing downstream needs them
//   when `convergence` is off.
// NOTE(review): in the shown hunk `priorAccuracy` is assigned only while it
//   equals 0, so `delta` is measured against that first stored accuracy -
//   confirm this is intended.
// NOTE(review): `map>&`, `.sum().item()` and the address-less From: header
//   look as if angle-bracketed text was lost in extraction - TODO restore
//   from the upstream commit before applying.
// NOTE(review): line breaks were rebuilt from the hunk line counts and the
//   indentation by convention - TODO verify whitespace against the target
//   tree before running git am / git apply.
 src/BayesNet/BoostAODE.cc | 64 ++++++++++++++++++++++-----------------
 src/BayesNet/Network.cc   |  6 ++--
 2 files changed, 40 insertions(+), 30 deletions(-)

diff --git a/src/BayesNet/BoostAODE.cc b/src/BayesNet/BoostAODE.cc
index ce9da7d..c976408 100644
--- a/src/BayesNet/BoostAODE.cc
+++ b/src/BayesNet/BoostAODE.cc
@@ -32,23 +32,31 @@ namespace bayesnet {
     void BoostAODE::validationInit()
     {
         auto y_ = dataset.index({ -1, "..." });
-        auto fold = platform::StratifiedKFold(5, y_, 271);
-        dataset_ = torch::clone(dataset);
-        // save input dataset
-        auto [train, test] = fold.getFold(0);
-        auto train_t = torch::tensor(train);
-        auto test_t = torch::tensor(test);
-        // Get train and validation sets
-        X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), train_t });
-        y_train = dataset.index({ -1, train_t });
-        X_test = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), test_t });
-        y_test = dataset.index({ -1, test_t });
-        dataset = X_train;
-        m = X_train.size(1);
-        auto n_classes = states.at(className).size();
-        metrics = Metrics(dataset, features, className, n_classes);
-        // Build dataset with train data
-        buildDataset(y_train);
+        if (convergence) {
+            // Prepare train & validation sets from train data
+            auto fold = platform::StratifiedKFold(5, y_, 271);
+            dataset_ = torch::clone(dataset);
+            // save input dataset
+            auto [train, test] = fold.getFold(0);
+            auto train_t = torch::tensor(train);
+            auto test_t = torch::tensor(test);
+            // Get train and validation sets
+            X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), train_t });
+            y_train = dataset.index({ -1, train_t });
+            X_test = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), test_t });
+            y_test = dataset.index({ -1, test_t });
+            dataset = X_train;
+            m = X_train.size(1);
+            auto n_classes = states.at(className).size();
+            metrics = Metrics(dataset, features, className, n_classes);
+            // Build dataset with train data
+            buildDataset(y_train);
+        } else {
+            // Use all data to train
+            X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." });
+            y_train = y_;
+        }
+
     }
     void BoostAODE::trainModel(const torch::Tensor& weights)
     {
@@ -64,7 +72,7 @@ namespace bayesnet {
         double priorAccuracy = 0.0;
         double delta = 1.0;
         double threshold = 1e-4;
-        int tolerance = convergence ? 5 : INT_MAX; // number of times the accuracy can be lower than the threshold
+        int tolerance = 5; // number of times the accuracy can be lower than the threshold
         int count = 0; // number of times the accuracy is lower than the threshold
         fitted = true; // to enable predict
         // Step 0: Set the finish condition
@@ -115,15 +123,17 @@ namespace bayesnet {
             models.push_back(std::move(model));
             significanceModels.push_back(alpha_t);
             n_models++;
-            auto y_val_predict = predict(X_test);
-            double accuracy = (y_val_predict == y_test).sum().item() / (double)y_test.size(0);
-            if (priorAccuracy == 0) {
-                priorAccuracy = accuracy;
-            } else {
-                delta = accuracy - priorAccuracy;
-            }
-            if (delta < threshold) {
-                count++;
+            if (convergence) {
+                auto y_val_predict = predict(X_test);
+                double accuracy = (y_val_predict == y_test).sum().item() / (double)y_test.size(0);
+                if (priorAccuracy == 0) {
+                    priorAccuracy = accuracy;
+                } else {
+                    delta = accuracy - priorAccuracy;
+                }
+                if (delta < threshold) {
+                    count++;
+                }
             }
             exitCondition = n_models == maxModels && repeatSparent || epsilon_t > 0.5 || count > tolerance;
         }
diff --git a/src/BayesNet/Network.cc b/src/BayesNet/Network.cc
index b9fc659..bcf4301 100644
--- a/src/BayesNet/Network.cc
+++ b/src/BayesNet/Network.cc
@@ -132,9 +132,9 @@ namespace bayesnet {
     void Network::setStates(const map>& states)
     {
         // Set states to every Node in the network
-        for (int i = 0; i < features.size(); ++i) {
-            nodes.at(features.at(i))->setNumStates(states.at(features[i]).size());
-        }
+        for_each(features.begin(), features.end(), [this, &states](const string& feature) {
+            nodes.at(feature)->setNumStates(states.at(feature).size());
+            });
         classNumStates = nodes.at(className)->getNumStates();
     }
     // X comes in nxm, where n is the number of features and m the number of samples