Update significance of select-features models

2024-03-05 12:10:58 +01:00
parent 093c197f0a
commit 0ee3eaed53
2 changed files with 43 additions and 21 deletions


@@ -1,6 +1,7 @@
#include <set>
#include <functional>
#include <limits.h>
#include <tuple>
#include "BoostAODE.h"
#include "CFS.h"
#include "FCBF.h"
@@ -112,6 +113,33 @@ namespace bayesnet {
throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump());
}
}
std::tuple<torch::Tensor&, double, bool> update_weights(torch::Tensor& ytrain, torch::Tensor& ypred, torch::Tensor& weights)
{
bool terminate = false;
double alpha_t = 0;
auto mask_wrong = ypred != ytrain;
auto mask_right = ypred == ytrain;
auto masked_weights = weights * mask_wrong.to(weights.dtype());
double epsilon_t = masked_weights.sum().item<double>();
if (epsilon_t > 0.5) {
// Inverse the weights policy (plot ln(wt))
// "In each round of AdaBoost, there is a sanity check to ensure that the current base
// learner is better than random guess" (Zhi-Hua Zhou, 2012)
terminate = true;
} else {
double wt = (1 - epsilon_t) / epsilon_t;
alpha_t = epsilon_t == 0 ? 1 : 0.5 * log(wt);
// Step 3.2: Update weights for next classifier
// Step 3.2.1: Update weights of wrong samples
weights += mask_wrong.to(weights.dtype()) * exp(alpha_t) * weights;
// Step 3.2.2: Update weights of right samples
weights += mask_right.to(weights.dtype()) * exp(-alpha_t) * weights;
// Step 3.3: Normalise the weights
double totalWeights = torch::sum(weights).item<double>();
weights = weights / totalWeights;
}
return { weights, alpha_t, terminate };
}
std::unordered_set<int> BoostAODE::initializeModels()
{
std::unordered_set<int> featuresUsed;
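Note: the new update_weights helper encapsulates the AdaBoost bookkeeping used throughout this file — the weighted error epsilon_t, the amount of say alpha_t = 0.5 * ln((1 - epsilon_t) / epsilon_t), the per-sample reweighting, and the renormalisation. Below is a minimal standalone sketch of that rule; the toy labels, predictions, and the main() wrapper are illustrative only and are not part of this commit.

#include <torch/torch.h>
#include <cmath>
#include <iostream>

int main()
{
    // Four samples with uniform weights 1/m and a single misclassification
    auto ytrain = torch::tensor({ 0, 1, 1, 0 }, torch::kInt64);
    auto ypred = torch::tensor({ 0, 1, 0, 0 }, torch::kInt64);
    auto weights = torch::full({ 4 }, 0.25, torch::kFloat64);
    auto mask_wrong = ypred != ytrain;
    auto mask_right = ypred == ytrain;
    // Weighted error epsilon_t = 0.25, so alpha_t = 0.5 * ln(0.75 / 0.25) ≈ 0.549
    double epsilon_t = (weights * mask_wrong.to(weights.dtype())).sum().item<double>();
    double alpha_t = 0.5 * std::log((1 - epsilon_t) / epsilon_t);
    // Same reweighting scheme as update_weights: misclassified samples end up with more relative weight
    weights += mask_wrong.to(weights.dtype()) * std::exp(alpha_t) * weights;
    weights += mask_right.to(weights.dtype()) * std::exp(-alpha_t) * weights;
    weights = weights / weights.sum();  // renormalise so the weights sum to 1
    std::cout << "alpha_t = " << alpha_t << "\n" << weights << std::endl;
    return 0;
}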
@@ -161,19 +189,29 @@ namespace bayesnet {
{
initialize_prob_table = true;
fitted = true;
double alpha_t = 0;
// Algorithm based on the adaboost algorithm for classification
// as explained in Ensemble methods (Zhi-Hua Zhou, 2012)
torch::Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
bool exitCondition = false;
std::unordered_set<int> featuresUsed;
if (selectFeatures) {
featuresUsed = initializeModels();
auto ypred = predict(X_train);
std::tie(weights_, alpha_t, exitCondition) = update_weights(y_train, ypred, weights_);
// Update significance of the models
for (int i = 0; i < n_models; ++i) {
significanceModels[i] = alpha_t;
}
if (exitCondition) {
return;
}
}
bool resetMaxModels = false;
if (maxModels == 0) {
maxModels = .1 * n > 10 ? .1 * n : n;
resetMaxModels = true; // Flag to unset maxModels
}
torch::Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
bool exitCondition = false;
// Variables to control the accuracy finish condition
double priorAccuracy = 0.0;
double delta = 1.0;
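Note: with the weight vector initialised before the select-features block, the models created by initializeModels() are scored once against X_train and all receive the same significance alpha_t. For example, if that first ensemble has a weighted error of epsilon_t = 0.2, each initial model gets alpha_t = 0.5 * ln(0.8 / 0.2) ≈ 0.693. A small illustrative sketch of that assignment (plain C++, hypothetical numbers, not part of the commit):

#include <cmath>
#include <iostream>
#include <vector>

int main()
{
    double epsilon_t = 0.2;   // assumed weighted error of the initial ensemble
    double alpha_t = 0.5 * std::log((1 - epsilon_t) / epsilon_t);   // ≈ 0.693
    std::vector<double> significanceModels(3, 0.0);   // pretend three initial models
    for (auto& significance : significanceModels) {
        significance = alpha_t;   // every initial model gets the same amount of say
    }
    std::cout << "shared alpha_t = " << alpha_t << std::endl;
    return 0;
}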
@@ -218,26 +256,10 @@ namespace bayesnet {
ypred = ensemble_predict(X_train, dynamic_cast<SPODE*>(model.get()));
}
// Step 3.1: Compute the classifier amount of say
auto mask_wrong = ypred != y_train;
auto mask_right = ypred == y_train;
auto masked_weights = weights_ * mask_wrong.to(weights_.dtype());
double epsilon_t = masked_weights.sum().item<double>();
if (epsilon_t > 0.5) {
// Inverse the weights policy (plot ln(wt))
// "In each round of AdaBoost, there is a sanity check to ensure that the current base
// learner is better than random guess" (Zhi-Hua Zhou, 2012)
std::tie(weights_, alpha_t, exitCondition) = update_weights(y_train, ypred, weights_);
if (exitCondition) {
break;
}
double wt = (1 - epsilon_t) / epsilon_t;
double alpha_t = epsilon_t == 0 ? 1 : 0.5 * log(wt);
// Step 3.2: Update weights for next classifier
// Step 3.2.1: Update weights of wrong samples
weights_ += mask_wrong.to(weights_.dtype()) * exp(alpha_t) * weights_;
// Step 3.2.2: Update weights of right samples
weights_ += mask_right.to(weights_.dtype()) * exp(-alpha_t) * weights_;
// Step 3.3: Normalise the weights
double totalWeights = torch::sum(weights_).item<double>();
weights_ = weights_ / totalWeights;
// Step 3.4: Store classifier and its accuracy to weigh its future vote
featuresUsed.insert(feature);
models.push_back(std::move(model));
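Note: inside the boosting loop the inline weight update is replaced by a call to update_weights, and the terminate flag carries the sanity check quoted above — a base learner whose weighted error exceeds 0.5 is worse than random guessing, its alpha_t would turn negative, so the loop breaks instead of continuing. A tiny illustrative sketch of that check (hypothetical error values, not part of the commit):

#include <cmath>
#include <iostream>

int main()
{
    for (double epsilon_t : { 0.3, 0.5, 0.7 }) {
        bool terminate = epsilon_t > 0.5;   // worse than random guessing under the current weights
        double alpha_t = terminate ? 0.0 : 0.5 * std::log((1 - epsilon_t) / epsilon_t);
        std::cout << "epsilon_t = " << epsilon_t
                  << "  terminate = " << std::boolalpha << terminate
                  << "  alpha_t = " << alpha_t << std::endl;
    }
    return 0;
}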