diff --git a/bayesnet/ensembles/Boost.cc b/bayesnet/ensembles/Boost.cc index 73d4adb..a582811 100644 --- a/bayesnet/ensembles/Boost.cc +++ b/bayesnet/ensembles/Boost.cc @@ -3,6 +3,7 @@ // SPDX-FileType: SOURCE // SPDX-License-Identifier: MIT // *************************************************************** +#include #include "bayesnet/feature_selection/CFS.h" #include "bayesnet/feature_selection/FCBF.h" #include "bayesnet/feature_selection/IWSS.h" @@ -67,6 +68,37 @@ namespace bayesnet { } Classifier::setHyperparameters(hyperparameters); } + void Boost::buildModel(const torch::Tensor& weights) + { + // Models shall be built in trainModel + models.clear(); + significanceModels.clear(); + n_models = 0; + // Prepare the validation dataset + auto y_ = dataset.index({ -1, "..." }); + if (convergence) { + // Prepare train & validation sets from train data + auto fold = folding::StratifiedKFold(5, y_, 271); + auto [train, test] = fold.getFold(0); + auto train_t = torch::tensor(train); + auto test_t = torch::tensor(test); + // Get train and validation sets + X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), train_t }); + y_train = dataset.index({ -1, train_t }); + X_test = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), test_t }); + y_test = dataset.index({ -1, test_t }); + dataset = X_train; + m = X_train.size(1); + auto n_classes = states.at(className).size(); + // Build dataset with train data + buildDataset(y_train); + metrics = Metrics(dataset, features, className, n_classes); + } else { + // Use all data to train + X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." }); + y_train = y_; + } + } std::vector Boost::featureSelection(torch::Tensor& weights_) { int maxFeatures = 0; diff --git a/bayesnet/ensembles/Boost.h b/bayesnet/ensembles/Boost.h index d00728b..2594bcb 100644 --- a/bayesnet/ensembles/Boost.h +++ b/bayesnet/ensembles/Boost.h @@ -31,6 +31,7 @@ namespace bayesnet { void setHyperparameters(const nlohmann::json& hyperparameters_) override; protected: std::vector featureSelection(torch::Tensor& weights_); + void buildModel(const torch::Tensor& weights) override; std::tuple update_weights(torch::Tensor& ytrain, torch::Tensor& ypred, torch::Tensor& weights); std::tuple update_weights_block(int k, torch::Tensor& ytrain, torch::Tensor& weights); torch::Tensor X_train, y_train, X_test, y_test; diff --git a/bayesnet/ensembles/BoostA2DE.cc b/bayesnet/ensembles/BoostA2DE.cc index 9ad85e9..77f2971 100644 --- a/bayesnet/ensembles/BoostA2DE.cc +++ b/bayesnet/ensembles/BoostA2DE.cc @@ -18,38 +18,6 @@ namespace bayesnet { BoostA2DE::BoostA2DE(bool predict_voting) : Boost(predict_voting) { - } - void BoostA2DE::buildModel(const torch::Tensor& weights) - { - // Models shall be built in trainModel - models.clear(); - significanceModels.clear(); - n_models = 0; - // Prepare the validation dataset - auto y_ = dataset.index({ -1, "..." }); - if (convergence) { - // Prepare train & validation sets from train data - auto fold = folding::StratifiedKFold(5, y_, 271); - auto [train, test] = fold.getFold(0); - auto train_t = torch::tensor(train); - auto test_t = torch::tensor(test); - // Get train and validation sets - X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), train_t }); - y_train = dataset.index({ -1, train_t }); - X_test = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), test_t }); - y_test = dataset.index({ -1, test_t }); - dataset = X_train; - m = X_train.size(1); - auto n_classes = states.at(className).size(); - // Build dataset with train data - buildDataset(y_train); - metrics = Metrics(dataset, features, className, n_classes); - } else { - // Use all data to train - X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." }); - y_train = y_; - } - } void BoostA2DE::trainModel(const torch::Tensor& weights) { diff --git a/bayesnet/ensembles/BoostA2DE.h b/bayesnet/ensembles/BoostA2DE.h index c81531e..34cfe24 100644 --- a/bayesnet/ensembles/BoostA2DE.h +++ b/bayesnet/ensembles/BoostA2DE.h @@ -17,7 +17,6 @@ namespace bayesnet { virtual ~BoostA2DE() = default; std::vector graph(const std::string& title = "BoostA2DE") const override; protected: - void buildModel(const torch::Tensor& weights) override; void trainModel(const torch::Tensor& weights) override; }; } diff --git a/bayesnet/ensembles/BoostAODE.cc b/bayesnet/ensembles/BoostAODE.cc index a1a748a..e4ea907 100644 --- a/bayesnet/ensembles/BoostAODE.cc +++ b/bayesnet/ensembles/BoostAODE.cc @@ -4,11 +4,11 @@ // SPDX-License-Identifier: MIT // *************************************************************** +#include #include #include #include #include -#include #include "BoostAODE.h" #include "lib/log/loguru.cpp" @@ -17,37 +17,7 @@ namespace bayesnet { BoostAODE::BoostAODE(bool predict_voting) : Boost(predict_voting) { } - void BoostAODE::buildModel(const torch::Tensor& weights) - { - // Models shall be built in trainModel - models.clear(); - significanceModels.clear(); - n_models = 0; - // Prepare the validation dataset - auto y_ = dataset.index({ -1, "..." }); - if (convergence) { - // Prepare train & validation sets from train data - auto fold = folding::StratifiedKFold(5, y_, 271); - auto [train, test] = fold.getFold(0); - auto train_t = torch::tensor(train); - auto test_t = torch::tensor(test); - // Get train and validation sets - X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), train_t }); - y_train = dataset.index({ -1, train_t }); - X_test = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), test_t }); - y_test = dataset.index({ -1, test_t }); - dataset = X_train; - m = X_train.size(1); - auto n_classes = states.at(className).size(); - // Build dataset with train data - buildDataset(y_train); - metrics = Metrics(dataset, features, className, n_classes); - } else { - // Use all data to train - X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." }); - y_train = y_; - } - } + std::vector BoostAODE::initializeModels() { torch::Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64); diff --git a/bayesnet/ensembles/BoostAODE.h b/bayesnet/ensembles/BoostAODE.h index d6da098..8c613a9 100644 --- a/bayesnet/ensembles/BoostAODE.h +++ b/bayesnet/ensembles/BoostAODE.h @@ -17,7 +17,6 @@ namespace bayesnet { virtual ~BoostAODE() = default; std::vector graph(const std::string& title = "BoostAODE") const override; protected: - void buildModel(const torch::Tensor& weights) override; void trainModel(const torch::Tensor& weights) override; private: std::vector initializeModels();