Extract buildModel method to parent class in Boost

This commit is contained in:
2024-05-15 20:00:44 +02:00
parent 54496c68f1
commit 8784a24898
6 changed files with 35 additions and 66 deletions

View File

@@ -4,11 +4,11 @@
// SPDX-License-Identifier: MIT
// ***************************************************************
#include <random>
#include <set>
#include <functional>
#include <limits.h>
#include <tuple>
#include <folding.hpp>
#include "BoostAODE.h"
#include "lib/log/loguru.cpp"
@@ -17,37 +17,7 @@ namespace bayesnet {
BoostAODE::BoostAODE(bool predict_voting) : Boost(predict_voting)
{
}
void BoostAODE::buildModel(const torch::Tensor& weights)
{
// Models shall be built in trainModel
models.clear();
significanceModels.clear();
n_models = 0;
// Prepare the validation dataset
auto y_ = dataset.index({ -1, "..." });
if (convergence) {
// Prepare train & validation sets from train data
auto fold = folding::StratifiedKFold(5, y_, 271);
auto [train, test] = fold.getFold(0);
auto train_t = torch::tensor(train);
auto test_t = torch::tensor(test);
// Get train and validation sets
X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), train_t });
y_train = dataset.index({ -1, train_t });
X_test = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), test_t });
y_test = dataset.index({ -1, test_t });
dataset = X_train;
m = X_train.size(1);
auto n_classes = states.at(className).size();
// Build dataset with train data
buildDataset(y_train);
metrics = Metrics(dataset, features, className, n_classes);
} else {
// Use all data to train
X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." });
y_train = y_;
}
}
std::vector<int> BoostAODE::initializeModels()
{
torch::Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);