Implement block update algorithm fix in BoostAODE

commit a2de1c9522 (parent cf9b5716ac)
Author: Ricardo Montañana Gómez
Date: 2024-04-11 00:02:43 +02:00
Signed by: rmontanana (GPG Key ID: 46064262FD9A7ADE)
2 changed files with 66 additions and 9 deletions

File 1 of 2: the BoostAODE implementation.

@@ -129,22 +129,67 @@ namespace bayesnet {
     }
     std::tuple<torch::Tensor&, double, bool> BoostAODE::update_weights_block(int k, torch::Tensor& ytrain, torch::Tensor& weights)
     {
+        /* Update Block algorithm
+            k = # of models in block
+            n_models = # of models in ensemble to make predictions
+            n_models_bak = # models saved
+            models = vector of models to make predictions
+            models_bak = models not used to make predictions
+            significances_bak = backup of significances vector
+            Case list
+            A) k = 1, n_models = 1     => n = 0, n_models = n + k
+            B) k = 1, n_models = n + 1 => n_models = n + k
+            C) k > 1, n_models = k + 1 => n = 1, n_models = n + k
+            D) k > 1, n_models = k     => n = 0, n_models = n + k
+            E) k > 1, n_models = k + n => n_models = n + k
+            A, D) n = 0, k > 0, n_models == k
+            1. n_models_bak <- n_models
+            2. significances_bak <- significances
+            3. significances = vector(k, 1)
+            4. Don't move any classifiers out of models
+            5. n_models <- k
+            6. Make prediction, compute alpha, update weights
+            7. Don't restore any classifiers to models
+            8. significances <- significances_bak
+            9. Update last k significances
+            10. n_models <- n_models_bak
+            B, C, E) n > 0, k > 0, n_models == n + k
+            1. n_models_bak <- n_models
+            2. significances_bak <- significances
+            3. significances = vector(k, 1)
+            4. Move first n classifiers to models_bak
+            5. n_models <- k
+            6. Make prediction, compute alpha, update weights
+            7. Insert classifiers in models_bak to be the first n models
+            8. significances <- significances_bak
+            9. Update last k significances
+            10. n_models <- n_models_bak
+        */
         //
         // Make predict with only the last k models
         //
         std::unique_ptr<Classifier> model;
         std::vector<std::unique_ptr<Classifier>> models_bak;
+        // 1. n_models_bak <- n_models 2. significances_bak <- significances
         auto significance_bak = significanceModels;
         auto n_models_bak = n_models;
-        // Remove the first n_models - k models
+        // 3. significances = vector(k, 1)
+        significanceModels = std::vector<double>(k, 1.0);
+        // 4. Move first n classifiers to models_bak
+        // backup the first n_models - k models (if n_models == k, don't backup any)
+        VLOG_SCOPE_F(1, "upd_weights_block n_models=%d k=%d", n_models, k);
         for (int i = 0; i < n_models - k; ++i) {
             model = std::move(models[0]);
             models.erase(models.begin());
             models_bak.push_back(std::move(model));
         }
         assert(models.size() == k);
-        significanceModels = std::vector<double>(k, 1.0);
+        // 5. n_models <- k
         n_models = k;
+        // 6. Make prediction, compute alpha, update weights
         auto ypred = predict(X_train);
         //
         // Update weights
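The numbered steps in the comment block map one-to-one onto the code that follows. As a standalone illustration of the backup phase (steps 1 through 5: set every significance in the block to 1 and move the first n classifiers aside so predict() only sees the last k), here is a minimal sketch; Model, Ensemble, and backup_front are hypothetical stand-ins for bayesnet::Classifier and the BoostAODE member state, not the library's API.

    // Standalone sketch of the backup phase (steps 1-5); Model/Ensemble are
    // illustrative stand-ins, not the real bayesnet types.
    #include <cassert>
    #include <memory>
    #include <vector>

    struct Model { int id; };

    struct Ensemble {
        std::vector<std::unique_ptr<Model>> models;
        std::vector<double> significances;
        int n_models = 0;
    };

    // Move the first n = n_models - k models aside so predictions use only the last k.
    std::vector<std::unique_ptr<Model>> backup_front(Ensemble& e, int k)
    {
        std::vector<std::unique_ptr<Model>> models_bak;
        e.significances = std::vector<double>(k, 1.0);   // 3. neutral weights for the block
        // 4. Cases A and D (n_models == k): the loop body never runs, nothing moves.
        for (int i = 0; i < e.n_models - k; ++i) {
            models_bak.push_back(std::move(e.models.front()));
            e.models.erase(e.models.begin());
        }
        assert(static_cast<int>(e.models.size()) == k);
        e.n_models = k;                                  // 5. predict() now sees k models
        return models_bak;
    }

    int main()
    {
        Ensemble e;
        for (int i = 0; i < 5; ++i) e.models.push_back(std::make_unique<Model>(Model{ i }));
        e.significances = { 0.9, 0.8, 0.7, 0.6, 0.5 };
        e.n_models = 5;
        auto bak = backup_front(e, 2);                   // case E: n = 3, k = 2
        assert(bak.size() == 3 && e.models.front()->id == 3);
        return 0;
    }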
@@ -155,20 +200,28 @@ namespace bayesnet {
         //
         // Restore the models if needed
         //
+        // 7. Insert classifiers in models_bak to be the first n models
+        // if n_models_bak == k, don't restore any, because none of them were moved
         if (k != n_models_bak) {
-            for (int i = k - 1; i >= 0; --i) {
-                model = std::move(models_bak[i]);
+            // Insert in the same order as they were extracted
+            int bak_size = models_bak.size();
+            for (int i = 0; i < bak_size; ++i) {
+                model = std::move(models_bak[bak_size - 1 - i]);
+                models_bak.erase(models_bak.end() - 1);
                 models.insert(models.begin(), std::move(model));
             }
         }
+        // 8. significances <- significances_bak
         significanceModels = significance_bak;
-        n_models = n_models_bak;
         //
         // Update the significance of the last k models
         //
+        // 9. Update last k significances
         for (int i = 0; i < k; ++i) {
-            significanceModels[n_models - k + i] = alpha_t;
+            significanceModels[n_models_bak - k + i] = alpha_t;
         }
+        // 10. n_models <- n_models_bak
+        n_models = n_models_bak;
         return { weights, alpha_t, terminate };
     }
     std::vector<int> BoostAODE::initializeModels()
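The core of the fix is in the restore phase. The previous loop ran exactly k iterations over models_bak, but models_bak holds n = n_models_bak - k saved classifiers, so whenever n != k it restored the wrong number of models (and indexed past the end of models_bak for k > n). The corrected loop walks models_bak from the back, erasing as it goes, which reinserts the saved classifiers in exactly the order they were extracted. A minimal sketch of steps 7 through 10, continuing the hypothetical Ensemble stand-in from the previous sketch:

    // Standalone sketch of the restore phase (steps 7-10), using the same
    // illustrative Ensemble/Model stand-ins as the backup sketch above.
    void restore_front(Ensemble& e, std::vector<std::unique_ptr<Model>>& models_bak,
                       std::vector<double> significances_bak,
                       int n_models_bak, int k, double alpha_t)
    {
        // 7. Reinsert from the back so the saved models regain their original order.
        while (!models_bak.empty()) {
            e.models.insert(e.models.begin(), std::move(models_bak.back()));
            models_bak.pop_back();
        }
        e.significances = std::move(significances_bak);       // 8. restore significances
        for (int i = 0; i < k; ++i) {
            e.significances[n_models_bak - k + i] = alpha_t;  // 9. last k get the block alpha
        }
        e.n_models = n_models_bak;                            // 10. restore the ensemble size
    }

Deferring n_models = n_models_bak until after step 9 and indexing with n_models_bak makes explicit that the block's alpha lands on the last k slots of the full, restored ensemble.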
@@ -265,9 +318,8 @@ namespace bayesnet {
                );
                int k = pow(2, tolerance);
                int counter = 0; // The model counter of the current pack
-               VLOG_SCOPE_F(1, "k=%d featureSelection.size: %zu", k, featureSelection.size());
+               VLOG_SCOPE_F(1, "counter=%d k=%d featureSelection.size: %zu", counter, k, featureSelection.size());
                while (counter++ < k && featureSelection.size() > 0) {
-                   VLOG_SCOPE_F(2, "counter: %d numItemsPack: %d", counter, numItemsPack);
                    auto feature = featureSelection[0];
                    featureSelection.erase(featureSelection.begin());
                    std::unique_ptr<Classifier> model;
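For context, this hunk sits in the training loop where each pack of models has size k = pow(2, tolerance), so packs double as the tolerance counter grows; the single remaining log line now reports counter alongside k instead of logging once per iteration at verbosity 2. A minimal sketch of that schedule, assuming tolerance simply takes successive integer values up to the configured maximum:

    // Sketch of the bisection pack-size schedule: each tolerance step doubles k.
    #include <cstdio>

    int main()
    {
        const int maxTolerance = 3;          // matches the test's {"maxTolerance", 3}
        for (int tolerance = 0; tolerance <= maxTolerance; ++tolerance) {
            int k = 1 << tolerance;          // integer form of pow(2, tolerance)
            std::printf("tolerance=%d -> train a pack of k=%d models\n", tolerance, k);
        }
        return 0;
    }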

File 2 of 2: the BoostAODE tests.

@@ -157,13 +157,18 @@ TEST_CASE("Bisection", "[BoostAODE]")
 TEST_CASE("Block Update", "[BoostAODE]")
 {
     auto clf = bayesnet::BoostAODE();
-    auto raw = RawDatasets("mfeat-factors", true);
+    // auto raw = RawDatasets("mfeat-factors", true);
+    auto raw = RawDatasets("glass", true);
     clf.setHyperparameters({
         {"bisection", true},
         {"block_update", true},
         {"maxTolerance", 3},
         {"convergence", true},
     });
+    // clf.setHyperparameters({
+    //     {"block_update", true},
+    //     });
     clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
     REQUIRE(clf.getNumberOfNodes() == 217);
     REQUIRE(clf.getNumberOfEdges() == 431);
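The commented-out alternative above suggests the path was also exercised with block_update alone. For completeness, a minimal sketch of driving the same configuration outside the Catch2 harness; RawDatasets and the hyperparameter names come from the test above, while score() is an assumption about the classifier interface, not confirmed by this diff:

    // Hypothetical standalone driver for the block-update path.
    auto clf = bayesnet::BoostAODE();
    clf.setHyperparameters({
        {"block_update", true},   // route weight updates through update_weights_block()
        {"bisection", true},      // grow packs as k = 2^tolerance
        {"maxTolerance", 3},
        {"convergence", true},
    });
    auto raw = RawDatasets("glass", true);
    clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
    double acc = clf.score(raw.Xv, raw.yv);  // assumed API; check the Classifier interface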