Complete refactor of XA1DE & XBAODE with new ExpClf class

This commit is contained in:
2025-02-26 16:55:04 +01:00
parent c63baf419f
commit 1a688f90b4
7 changed files with 319 additions and 402 deletions

View File

@@ -11,17 +11,16 @@
#include "XBAODE.h"
#include "TensorUtils.hpp"
#include <loguru.hpp>
#include <loguru.cpp>
namespace platform {
XBAODE::XBAODE() : semaphore_{ CountingSemaphore::getInstance() }, Boost(false)
XBAODE::XBAODE() : Boost(false)
{
validHyperparameters = { "alpha_block", "order", "convergence", "convergence_best", "bisection", "threshold", "maxTolerance",
Boost::validHyperparameters = { "alpha_block", "order", "convergence", "convergence_best", "bisection", "threshold", "maxTolerance",
"predict_voting", "select_features" };
}
void XBAODE::trainModel(const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing)
{
fitted = true;
Boost::fitted = true;
X_train_ = TensorUtils::to_matrix(X_train);
y_train_ = TensorUtils::to_vector<int>(y_train);
X_test_ = TensorUtils::to_matrix(X_test);
@@ -40,18 +39,17 @@ namespace platform {
torch::Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
bool finished = false;
std::vector<int> featuresUsed;
significanceModels.resize(n, 0.0); // n possible spodes
aode_.fit(X_train_, y_train_, features, className, states, smoothing);
aode_.fit(X_train_, y_train_, features, className, states, weights_, false);
n_models = 0;
if (selectFeatures) {
featuresUsed = featureSelection(weights_);
aode_.set_active_parents(featuresUsed);
notes.push_back("Used features in initialization: " + std::to_string(featuresUsed.size()) + " of " + std::to_string(features.size()) + " with " + select_features_algorithm);
auto ypred = aode_.predict(X_train);
set_active_parents(featuresUsed);
Boost::notes.push_back("Used features in initialization: " + std::to_string(featuresUsed.size()) + " of " + std::to_string(features.size()) + " with " + select_features_algorithm);
auto ypred = ExpClf::predict(X_train);
std::tie(weights_, alpha_t, finished) = update_weights(y_train, ypred, weights_);
// Update significance of the models
for (const auto& parent : featuresUsed) {
significanceModels[parent] = alpha_t;
aode_.significance_models[parent] = alpha_t;
}
n_models = featuresUsed.size();
VLOG_SCOPE_F(1, "SelectFeatures. alpha_t: %f n_models: %d", alpha_t, n_models);
@@ -88,7 +86,7 @@ namespace platform {
while (counter++ < k && featureSelection.size() > 0) {
auto feature = featureSelection[0];
featureSelection.erase(featureSelection.begin());
aode_.add_active_parent(feature);
add_active_parent(feature);
alpha_t = 0.0;
std::vector<int> ypred;
if (alpha_block) {
@@ -97,16 +95,16 @@ namespace platform {
//
// Add the model to the ensemble
n_models++;
significanceModels[feature] = 1.0;
aode_.significance_models[feature] = 1.0;
aode_.add_active_parent(feature);
// Compute the prediction
ypred = aode_.predict(X_train_);
ypred = ExpClf::predict(X_train_);
// Remove the model from the ensemble
significanceModels[feature] = 0.0;
aode_.significance_models[feature] = 0.0;
aode_.remove_last_parent();
n_models--;
} else {
ypred = aode_.predict_spode(X_train_, feature);
ypred = predict_spode(X_train_, feature);
}
// Step 3.1: Compute the classifier amout of say
auto ypred_t = torch::tensor(ypred);
@@ -115,12 +113,12 @@ namespace platform {
numItemsPack++;
featuresUsed.push_back(feature);
aode_.add_active_parent(feature);
significanceModels.push_back(alpha_t);
aode_.significance_models[feature] = alpha_t;
n_models++;
VLOG_SCOPE_F(2, "finished: %d numItemsPack: %d n_models: %d featuresUsed: %zu", finished, numItemsPack, n_models, featuresUsed.size());
} // End of the pack
if (convergence && !finished) {
auto y_val_predict = predict(X_test);
auto y_val_predict = ExpClf::predict(X_test);
double accuracy = (y_val_predict == y_test).sum().item<double>() / (double)y_test.size(0);
if (priorAccuracy == 0) {
priorAccuracy = accuracy;
@@ -148,79 +146,24 @@ namespace platform {
}
if (tolerance > maxTolerance) {
if (numItemsPack < n_models) {
notes.push_back("Convergence threshold reached & " + std::to_string(numItemsPack) + " models eliminated");
Boost::notes.push_back("Convergence threshold reached & " + std::to_string(numItemsPack) + " models eliminated");
VLOG_SCOPE_F(4, "Convergence threshold reached & %d models eliminated of %d", numItemsPack, n_models);
for (int i = 0; i < numItemsPack; ++i) {
significanceModels.pop_back();
models.pop_back();
for (int i = featuresUsed.size() - 1; i >= featuresUsed.size() - numItemsPack; --i) {
aode_.remove_last_parent();
aode_.significance_models[featuresUsed[i]] = 0.0;
n_models--;
}
VLOG_SCOPE_F(4, "*Convergence threshold %d models left & %d features used.", n_models, featuresUsed.size());
} else {
notes.push_back("Convergence threshold reached & 0 models eliminated");
Boost::notes.push_back("Convergence threshold reached & 0 models eliminated");
VLOG_SCOPE_F(4, "Convergence threshold reached & 0 models eliminated n_models=%d numItemsPack=%d", n_models, numItemsPack);
}
}
if (featuresUsed.size() != features.size()) {
notes.push_back("Used features in train: " + std::to_string(featuresUsed.size()) + " of " + std::to_string(features.size()));
status = bayesnet::WARNING;
Boost::notes.push_back("Used features in train: " + std::to_string(featuresUsed.size()) + " of " + std::to_string(features.size()));
Boost::status = bayesnet::WARNING;
}
notes.push_back("Number of models: " + std::to_string(n_models));
Boost::notes.push_back("Number of models: " + std::to_string(n_models));
return;
}
//
// Predict
//
std::vector<std::vector<double>> XBAODE::predict_proba(std::vector<std::vector<int>>& test_data)
{
return aode_.predict_proba_threads(test_data);
}
std::vector<int> XBAODE::predict(std::vector<std::vector<int>>& test_data)
{
if (!fitted) {
throw std::logic_error(CLASSIFIER_NOT_FITTED);
}
return aode_.predict(test_data);
}
float XBAODE::score(std::vector<std::vector<int>>& test_data, std::vector<int>& labels)
{
return aode_.score(test_data, labels);
}
//
// statistics
//
int XBAODE::getNumberOfNodes() const
{
return aode_.getNumberOfNodes();
}
int XBAODE::getNumberOfEdges() const
{
return aode_.getNumberOfEdges();
}
int XBAODE::getNumberOfStates() const
{
return aode_.getNumberOfStates();
}
int XBAODE::getClassNumStates() const
{
return aode_.getClassNumStates();
}
//
// Predict
//
torch::Tensor XBAODE::predict(torch::Tensor& X)
{
return aode_.predict(X);
}
torch::Tensor XBAODE::predict_proba(torch::Tensor& X)
{
return aode_.predict_proba(X);
}
float XBAODE::score(torch::Tensor& X, torch::Tensor& y)
{
return aode_.score(X, y);
}
}