Continue refactoring
This commit is contained in:
@@ -13,6 +13,22 @@ namespace platform {
|
|||||||
validHyperparameters = {};
|
validHyperparameters = {};
|
||||||
}
|
}
|
||||||
//
|
//
|
||||||
|
// Parents
|
||||||
|
//
|
||||||
|
void ExpClf::add_active_parents(const std::vector<int>& active_parents)
|
||||||
|
{
|
||||||
|
for (const auto& parent : active_parents)
|
||||||
|
aode_.add_active_parent(parent);
|
||||||
|
}
|
||||||
|
void ExpClf::add_active_parent(int parent)
|
||||||
|
{
|
||||||
|
aode_.add_active_parent(parent);
|
||||||
|
}
|
||||||
|
void ExpClf::remove_last_parent()
|
||||||
|
{
|
||||||
|
aode_.remove_last_parent();
|
||||||
|
}
|
||||||
|
//
|
||||||
// Predict
|
// Predict
|
||||||
//
|
//
|
||||||
std::vector<int> ExpClf::predict_spode(std::vector<std::vector<int>>& test_data, int parent)
|
std::vector<int> ExpClf::predict_spode(std::vector<std::vector<int>>& test_data, int parent)
|
||||||
|
@@ -39,9 +39,9 @@ namespace platform {
|
|||||||
bayesnet::status_t getStatus() const override { return status; }
|
bayesnet::status_t getStatus() const override { return status; }
|
||||||
std::vector<std::string> getNotes() const override { return notes; }
|
std::vector<std::string> getNotes() const override { return notes; }
|
||||||
std::vector<std::string> graph(const std::string& title = "") const override { return {}; }
|
std::vector<std::string> graph(const std::string& title = "") const override { return {}; }
|
||||||
void set_active_parents(const std::vector<int>& active_parents) { for (const auto& parent : active_parents) aode_.add_active_parent(parent); }
|
void add_active_parents(const std::vector<int>& active_parents);
|
||||||
void add_active_parent(int parent) { aode_.add_active_parent(parent); }
|
void add_active_parent(int parent);
|
||||||
void remove_last_parent() { aode_.remove_last_parent(); }
|
void remove_last_parent();
|
||||||
protected:
|
protected:
|
||||||
bool debug = false;
|
bool debug = false;
|
||||||
Xaode aode_;
|
Xaode aode_;
|
||||||
|
@@ -13,7 +13,7 @@ namespace platform {
|
|||||||
auto X = TensorUtils::to_matrix(dataset.slice(0, 0, dataset.size(0) - 1));
|
auto X = TensorUtils::to_matrix(dataset.slice(0, 0, dataset.size(0) - 1));
|
||||||
auto y = TensorUtils::to_vector<int>(dataset.index({ -1, "..." }));
|
auto y = TensorUtils::to_vector<int>(dataset.index({ -1, "..." }));
|
||||||
int num_instances = X[0].size();
|
int num_instances = X[0].size();
|
||||||
weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
|
weights_ = weights;
|
||||||
normalize_weights(num_instances);
|
normalize_weights(num_instances);
|
||||||
aode_.fit(X, y, features, className, states, weights_, true);
|
aode_.fit(X, y, features, className, states, weights_, true);
|
||||||
}
|
}
|
||||||
|
@@ -43,13 +43,13 @@ namespace platform {
|
|||||||
n_models = 0;
|
n_models = 0;
|
||||||
if (selectFeatures) {
|
if (selectFeatures) {
|
||||||
featuresUsed = featureSelection(weights_);
|
featuresUsed = featureSelection(weights_);
|
||||||
set_active_parents(featuresUsed);
|
add_active_parents(featuresUsed);
|
||||||
notes.push_back("Used features in initialization: " + std::to_string(featuresUsed.size()) + " of " + std::to_string(features.size()) + " with " + select_features_algorithm);
|
notes.push_back("Used features in initialization: " + std::to_string(featuresUsed.size()) + " of " + std::to_string(features.size()) + " with " + select_features_algorithm);
|
||||||
auto ypred = ExpClf::predict(X_train);
|
auto ypred = ExpClf::predict(X_train);
|
||||||
std::tie(weights_, alpha_t, finished) = update_weights(y_train, ypred, weights_);
|
std::tie(weights_, alpha_t, finished) = update_weights(y_train, ypred, weights_);
|
||||||
// Update significance of the models
|
// Update significance of the models
|
||||||
for (const auto& parent : featuresUsed) {
|
for (const auto& parent : featuresUsed) {
|
||||||
aode_.significance_models[parent] = alpha_t;
|
aode_.significance_models_[parent] = alpha_t;
|
||||||
}
|
}
|
||||||
n_models = featuresUsed.size();
|
n_models = featuresUsed.size();
|
||||||
VLOG_SCOPE_F(1, "SelectFeatures. alpha_t: %f n_models: %d", alpha_t, n_models);
|
VLOG_SCOPE_F(1, "SelectFeatures. alpha_t: %f n_models: %d", alpha_t, n_models);
|
||||||
@@ -95,12 +95,12 @@ namespace platform {
|
|||||||
//
|
//
|
||||||
// Add the model to the ensemble
|
// Add the model to the ensemble
|
||||||
n_models++;
|
n_models++;
|
||||||
aode_.significance_models[feature] = 1.0;
|
aode_.significance_models_[feature] = 1.0;
|
||||||
aode_.add_active_parent(feature);
|
aode_.add_active_parent(feature);
|
||||||
// Compute the prediction
|
// Compute the prediction
|
||||||
ypred = ExpClf::predict(X_train_);
|
ypred = ExpClf::predict(X_train_);
|
||||||
// Remove the model from the ensemble
|
// Remove the model from the ensemble
|
||||||
aode_.significance_models[feature] = 0.0;
|
aode_.significance_models_[feature] = 0.0;
|
||||||
aode_.remove_last_parent();
|
aode_.remove_last_parent();
|
||||||
n_models--;
|
n_models--;
|
||||||
} else {
|
} else {
|
||||||
@@ -113,7 +113,7 @@ namespace platform {
|
|||||||
numItemsPack++;
|
numItemsPack++;
|
||||||
featuresUsed.push_back(feature);
|
featuresUsed.push_back(feature);
|
||||||
aode_.add_active_parent(feature);
|
aode_.add_active_parent(feature);
|
||||||
aode_.significance_models[feature] = alpha_t;
|
aode_.significance_models_[feature] = alpha_t;
|
||||||
n_models++;
|
n_models++;
|
||||||
VLOG_SCOPE_F(2, "finished: %d numItemsPack: %d n_models: %d featuresUsed: %zu", finished, numItemsPack, n_models, featuresUsed.size());
|
VLOG_SCOPE_F(2, "finished: %d numItemsPack: %d n_models: %d featuresUsed: %zu", finished, numItemsPack, n_models, featuresUsed.size());
|
||||||
} // End of the pack
|
} // End of the pack
|
||||||
@@ -150,7 +150,7 @@ namespace platform {
|
|||||||
VLOG_SCOPE_F(4, "Convergence threshold reached & %d models eliminated of %d", numItemsPack, n_models);
|
VLOG_SCOPE_F(4, "Convergence threshold reached & %d models eliminated of %d", numItemsPack, n_models);
|
||||||
for (int i = featuresUsed.size() - 1; i >= featuresUsed.size() - numItemsPack; --i) {
|
for (int i = featuresUsed.size() - 1; i >= featuresUsed.size() - numItemsPack; --i) {
|
||||||
aode_.remove_last_parent();
|
aode_.remove_last_parent();
|
||||||
aode_.significance_models[featuresUsed[i]] = 0.0;
|
aode_.significance_models_[featuresUsed[i]] = 0.0;
|
||||||
n_models--;
|
n_models--;
|
||||||
}
|
}
|
||||||
VLOG_SCOPE_F(4, "*Convergence threshold %d models left & %d features used.", n_models, featuresUsed.size());
|
VLOG_SCOPE_F(4, "*Convergence threshold %d models left & %d features used.", n_models, featuresUsed.size());
|
||||||
|
@@ -29,7 +29,7 @@ namespace platform {
|
|||||||
COUNTS,
|
COUNTS,
|
||||||
PROBS
|
PROBS
|
||||||
};
|
};
|
||||||
std::vector<double> significance_models;
|
std::vector<double> significance_models_;
|
||||||
Xaode() : nFeatures_{ 0 }, statesClass_{ 0 }, matrixState_{ MatrixState::EMPTY } {}
|
Xaode() : nFeatures_{ 0 }, statesClass_{ 0 }, matrixState_{ MatrixState::EMPTY } {}
|
||||||
// -------------------------------------------------------
|
// -------------------------------------------------------
|
||||||
// fit
|
// fit
|
||||||
@@ -43,7 +43,7 @@ namespace platform {
|
|||||||
int num_instances = X[0].size();
|
int num_instances = X[0].size();
|
||||||
int n_features_ = X.size();
|
int n_features_ = X.size();
|
||||||
|
|
||||||
significance_models.resize(n_features_, (all_parents ? 1.0 : 0.0));
|
significance_models_.resize(n_features_, (all_parents ? 1.0 : 0.0));
|
||||||
std::vector<int> statesv;
|
std::vector<int> statesv;
|
||||||
for (int i = 0; i < n_features_; i++) {
|
for (int i = 0; i < n_features_; i++) {
|
||||||
if (all_parents) active_parents.push_back(i);
|
if (all_parents) active_parents.push_back(i);
|
||||||
|
Reference in New Issue
Block a user