From 5826702fc75d444131cfae31911074a2d084aded Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Wed, 20 Mar 2024 12:01:57 +0100 Subject: [PATCH] Remove weights backup --- bayesnet/ensembles/BoostAODE.cc | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/bayesnet/ensembles/BoostAODE.cc b/bayesnet/ensembles/BoostAODE.cc index 19ab74f..c449b73 100644 --- a/bayesnet/ensembles/BoostAODE.cc +++ b/bayesnet/ensembles/BoostAODE.cc @@ -208,13 +208,9 @@ namespace bayesnet { // run out of features bool ascending = order_algorithm == Orders.ASC; std::mt19937 g{ 173 }; - torch::Tensor weights_backup; - // LOG_SCOPE_FUNCTION(INFO); - // LOG_F(INFO, "Train model..."); while (!finished) { // Step 1: Build ranking with mutual information auto featureSelection = metrics.SelectKBestWeighted(weights_, ascending, n); // Get all the features sorted - //LOG_S(INFO) << "1:featureSelection.size: " << featureSelection.size() << " featuresUsed.size: " << featuresUsed.size(); VLOG_SCOPE_F(1, "featureSelection.size: %d featuresUsed.size: %d", featureSelection.size(), featuresUsed.size()); if (order_algorithm == Orders.RAND) { std::shuffle(featureSelection.begin(), featureSelection.end(), g); @@ -226,10 +222,8 @@ namespace bayesnet { ); int k = pow(2, tolerance); int counter = 0; // The model counter of the current pack - // LOG_S(INFO) << "k=" << k; VLOG_SCOPE_F(1, "k=%d", k); while (counter++ < k && featureSelection.size() > 0) { - // LOG_S(INFO) << "2:counter: " << counter << " numItemsPack: " << numItemsPack << " featureSelection.size: " << featureSelection.size(); VLOG_SCOPE_F(2, "counter: %d numItemsPack: %d featureSelection.size: %d", counter, numItemsPack, featureSelection.size()); auto feature = featureSelection[0]; featureSelection.erase(featureSelection.begin()); @@ -237,15 +231,10 @@ namespace bayesnet { model = std::make_unique(feature); model->fit(dataset, features, className, states, weights_); torch::Tensor ypred; - //LOG_S(INFO) << "2:Begin model predict"; ypred = model->predict(X_train); - //LOG_S(INFO) << "2:End model predict"; // Step 3.1: Compute the classifier amout of say - weights_backup = weights_.clone(); std::tie(weights_, alpha_t, finished) = update_weights(y_train, ypred, weights_); if (finished) { - weights_ = weights_backup.clone(); - // LOG_S(INFO) << "2:** epsilon_t > 0.5 **"; VLOG_SCOPE_F(2, "** epsilon_t > 0.5 **"); break; } @@ -257,23 +246,18 @@ namespace bayesnet { n_models++; } if (convergence && !finished) { - //LOG_S(INFO) << "3:Begin ensemble predict"; auto y_val_predict = predict(X_test); - //LOG_S(INFO) << "3:End ensemble predict"; double accuracy = (y_val_predict == y_test).sum().item() / (double)y_test.size(0); if (priorAccuracy == 0) { priorAccuracy = accuracy; - // LOG_S(INFO) << "3:First accuracyb_manage: " << std::to_string(priorAccuracy); VLOG_SCOPE_F(3, "First accuracy: %f", priorAccuracy); } else { delta = accuracy - priorAccuracy; } if (delta < convergence_threshold) { - // LOG_S(INFO) << "3:* tolerance: " << tolerance << " numItemsPack: " << numItemsPack << " delta: " << delta << " prior: " << priorAccuracy << " current: " << accuracy << std::endl; VLOG_SCOPE_F(3, "(delta=threshold) Reset. tolerance: %d numItemsPack: %d delta: %f prior: %f current: %f", tolerance, numItemsPack, delta, priorAccuracy, accuracy); tolerance = 0; // Reset the counter if the model performs better numItemsPack = 0; @@ -287,16 +271,13 @@ namespace bayesnet { if (tolerance > maxTolerance) { if (numItemsPack < n_models) { notes.push_back("Convergence threshold reached & " + std::to_string(numItemsPack) + " models eliminated"); - // LOG_S(INFO) << "4:Convergence threshold reached & " << numItemsPack << " models eliminated" << " of " << n_models << std::endl; VLOG_SCOPE_F(4, "Convergence threshold reached & %d models eliminated of %d", numItemsPack, n_models); - weights_ = weights_backup; for (int i = 0; i < numItemsPack; ++i) { significanceModels.pop_back(); models.pop_back(); n_models--; } } else { - // LOG_S(INFO) << "4:Convergence threshold reached & 0 models eliminated n_models=" << n_models << " numItemsPack=" << numItemsPack; VLOG_SCOPE_F(4, "Convergence threshold reached & 0 models eliminated n_models=%d numItemsPack=%d", n_models, numItemsPack); notes.push_back("Convergence threshold reached & 0 models eliminated"); }