Implement the functionality of the hyperparameter alpha_block with test

This commit is contained in:
2024-12-14 14:02:45 +01:00
parent 18844c7da7
commit 3d6b4f0614
2 changed files with 36 additions and 3 deletions

View File

@@ -92,7 +92,25 @@ namespace bayesnet {
model->fit(dataset, features, className, states, weights_, smoothing);
alpha_t = 0.0;
if (!block_update) {
auto ypred = model->predict(X_train);
torch::Tensor ypred;
if (alpha_block) {
//
// Compute the prediction with the current ensemble + model
//
// Add the model to the ensemble
n_models++;
models.push_back(std::move(model));
significanceModels.push_back(1);
// Compute the prediction
ypred = predict(X_train);
// Remove the model from the ensemble
model = std::move(models.back());
models.pop_back();
significanceModels.pop_back();
n_models--;
} else {
ypred = model->predict(X_train);
}
// Step 3.1: Compute the classifier amout of say
std::tie(weights_, alpha_t, finished) = update_weights(y_train, ypred, weights_);
}