Continue refactoring

This commit is contained in:
2025-02-27 09:57:40 +01:00
parent 4e3043b2d1
commit f51d5b5e40
5 changed files with 28 additions and 12 deletions

View File

@@ -13,6 +13,22 @@ namespace platform {
validHyperparameters = {};
}
//
// Parents
//
void ExpClf::add_active_parents(const std::vector<int>& active_parents)
{
for (const auto& parent : active_parents)
aode_.add_active_parent(parent);
}
void ExpClf::add_active_parent(int parent)
{
aode_.add_active_parent(parent);
}
void ExpClf::remove_last_parent()
{
aode_.remove_last_parent();
}
//
// Predict
//
std::vector<int> ExpClf::predict_spode(std::vector<std::vector<int>>& test_data, int parent)

View File

@@ -39,9 +39,9 @@ namespace platform {
bayesnet::status_t getStatus() const override { return status; }
std::vector<std::string> getNotes() const override { return notes; }
std::vector<std::string> graph(const std::string& title = "") const override { return {}; }
void set_active_parents(const std::vector<int>& active_parents) { for (const auto& parent : active_parents) aode_.add_active_parent(parent); }
void add_active_parent(int parent) { aode_.add_active_parent(parent); }
void remove_last_parent() { aode_.remove_last_parent(); }
void add_active_parents(const std::vector<int>& active_parents);
void add_active_parent(int parent);
void remove_last_parent();
protected:
bool debug = false;
Xaode aode_;

View File

@@ -13,7 +13,7 @@ namespace platform {
auto X = TensorUtils::to_matrix(dataset.slice(0, 0, dataset.size(0) - 1));
auto y = TensorUtils::to_vector<int>(dataset.index({ -1, "..." }));
int num_instances = X[0].size();
weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
weights_ = weights;
normalize_weights(num_instances);
aode_.fit(X, y, features, className, states, weights_, true);
}

View File

@@ -43,13 +43,13 @@ namespace platform {
n_models = 0;
if (selectFeatures) {
featuresUsed = featureSelection(weights_);
set_active_parents(featuresUsed);
add_active_parents(featuresUsed);
notes.push_back("Used features in initialization: " + std::to_string(featuresUsed.size()) + " of " + std::to_string(features.size()) + " with " + select_features_algorithm);
auto ypred = ExpClf::predict(X_train);
std::tie(weights_, alpha_t, finished) = update_weights(y_train, ypred, weights_);
// Update significance of the models
for (const auto& parent : featuresUsed) {
aode_.significance_models[parent] = alpha_t;
aode_.significance_models_[parent] = alpha_t;
}
n_models = featuresUsed.size();
VLOG_SCOPE_F(1, "SelectFeatures. alpha_t: %f n_models: %d", alpha_t, n_models);
@@ -95,12 +95,12 @@ namespace platform {
//
// Add the model to the ensemble
n_models++;
aode_.significance_models[feature] = 1.0;
aode_.significance_models_[feature] = 1.0;
aode_.add_active_parent(feature);
// Compute the prediction
ypred = ExpClf::predict(X_train_);
// Remove the model from the ensemble
aode_.significance_models[feature] = 0.0;
aode_.significance_models_[feature] = 0.0;
aode_.remove_last_parent();
n_models--;
} else {
@@ -113,7 +113,7 @@ namespace platform {
numItemsPack++;
featuresUsed.push_back(feature);
aode_.add_active_parent(feature);
aode_.significance_models[feature] = alpha_t;
aode_.significance_models_[feature] = alpha_t;
n_models++;
VLOG_SCOPE_F(2, "finished: %d numItemsPack: %d n_models: %d featuresUsed: %zu", finished, numItemsPack, n_models, featuresUsed.size());
} // End of the pack
@@ -150,7 +150,7 @@ namespace platform {
VLOG_SCOPE_F(4, "Convergence threshold reached & %d models eliminated of %d", numItemsPack, n_models);
for (int i = featuresUsed.size() - 1; i >= featuresUsed.size() - numItemsPack; --i) {
aode_.remove_last_parent();
aode_.significance_models[featuresUsed[i]] = 0.0;
aode_.significance_models_[featuresUsed[i]] = 0.0;
n_models--;
}
VLOG_SCOPE_F(4, "*Convergence threshold %d models left & %d features used.", n_models, featuresUsed.size());

View File

@@ -29,7 +29,7 @@ namespace platform {
COUNTS,
PROBS
};
std::vector<double> significance_models;
std::vector<double> significance_models_;
Xaode() : nFeatures_{ 0 }, statesClass_{ 0 }, matrixState_{ MatrixState::EMPTY } {}
// -------------------------------------------------------
// fit
@@ -43,7 +43,7 @@ namespace platform {
int num_instances = X[0].size();
int n_features_ = X.size();
significance_models.resize(n_features_, (all_parents ? 1.0 : 0.0));
significance_models_.resize(n_features_, (all_parents ? 1.0 : 0.0));
std::vector<int> statesv;
for (int i = 0; i < n_features_; i++) {
if (all_parents) active_parents.push_back(i);