Extract buildModel method to parent class in Boost
This commit is contained in:
parent
54496c68f1
commit
8784a24898
@ -3,6 +3,7 @@
|
|||||||
// SPDX-FileType: SOURCE
|
// SPDX-FileType: SOURCE
|
||||||
// SPDX-License-Identifier: MIT
|
// SPDX-License-Identifier: MIT
|
||||||
// ***************************************************************
|
// ***************************************************************
|
||||||
|
#include <folding.hpp>
|
||||||
#include "bayesnet/feature_selection/CFS.h"
|
#include "bayesnet/feature_selection/CFS.h"
|
||||||
#include "bayesnet/feature_selection/FCBF.h"
|
#include "bayesnet/feature_selection/FCBF.h"
|
||||||
#include "bayesnet/feature_selection/IWSS.h"
|
#include "bayesnet/feature_selection/IWSS.h"
|
||||||
@ -67,6 +68,37 @@ namespace bayesnet {
|
|||||||
}
|
}
|
||||||
Classifier::setHyperparameters(hyperparameters);
|
Classifier::setHyperparameters(hyperparameters);
|
||||||
}
|
}
|
||||||
|
void Boost::buildModel(const torch::Tensor& weights)
|
||||||
|
{
|
||||||
|
// Models shall be built in trainModel
|
||||||
|
models.clear();
|
||||||
|
significanceModels.clear();
|
||||||
|
n_models = 0;
|
||||||
|
// Prepare the validation dataset
|
||||||
|
auto y_ = dataset.index({ -1, "..." });
|
||||||
|
if (convergence) {
|
||||||
|
// Prepare train & validation sets from train data
|
||||||
|
auto fold = folding::StratifiedKFold(5, y_, 271);
|
||||||
|
auto [train, test] = fold.getFold(0);
|
||||||
|
auto train_t = torch::tensor(train);
|
||||||
|
auto test_t = torch::tensor(test);
|
||||||
|
// Get train and validation sets
|
||||||
|
X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), train_t });
|
||||||
|
y_train = dataset.index({ -1, train_t });
|
||||||
|
X_test = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), test_t });
|
||||||
|
y_test = dataset.index({ -1, test_t });
|
||||||
|
dataset = X_train;
|
||||||
|
m = X_train.size(1);
|
||||||
|
auto n_classes = states.at(className).size();
|
||||||
|
// Build dataset with train data
|
||||||
|
buildDataset(y_train);
|
||||||
|
metrics = Metrics(dataset, features, className, n_classes);
|
||||||
|
} else {
|
||||||
|
// Use all data to train
|
||||||
|
X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." });
|
||||||
|
y_train = y_;
|
||||||
|
}
|
||||||
|
}
|
||||||
std::vector<int> Boost::featureSelection(torch::Tensor& weights_)
|
std::vector<int> Boost::featureSelection(torch::Tensor& weights_)
|
||||||
{
|
{
|
||||||
int maxFeatures = 0;
|
int maxFeatures = 0;
|
||||||
|
@ -31,6 +31,7 @@ namespace bayesnet {
|
|||||||
void setHyperparameters(const nlohmann::json& hyperparameters_) override;
|
void setHyperparameters(const nlohmann::json& hyperparameters_) override;
|
||||||
protected:
|
protected:
|
||||||
std::vector<int> featureSelection(torch::Tensor& weights_);
|
std::vector<int> featureSelection(torch::Tensor& weights_);
|
||||||
|
void buildModel(const torch::Tensor& weights) override;
|
||||||
std::tuple<torch::Tensor&, double, bool> update_weights(torch::Tensor& ytrain, torch::Tensor& ypred, torch::Tensor& weights);
|
std::tuple<torch::Tensor&, double, bool> update_weights(torch::Tensor& ytrain, torch::Tensor& ypred, torch::Tensor& weights);
|
||||||
std::tuple<torch::Tensor&, double, bool> update_weights_block(int k, torch::Tensor& ytrain, torch::Tensor& weights);
|
std::tuple<torch::Tensor&, double, bool> update_weights_block(int k, torch::Tensor& ytrain, torch::Tensor& weights);
|
||||||
torch::Tensor X_train, y_train, X_test, y_test;
|
torch::Tensor X_train, y_train, X_test, y_test;
|
||||||
|
@ -18,38 +18,6 @@ namespace bayesnet {
|
|||||||
|
|
||||||
BoostA2DE::BoostA2DE(bool predict_voting) : Boost(predict_voting)
|
BoostA2DE::BoostA2DE(bool predict_voting) : Boost(predict_voting)
|
||||||
{
|
{
|
||||||
}
|
|
||||||
void BoostA2DE::buildModel(const torch::Tensor& weights)
|
|
||||||
{
|
|
||||||
// Models shall be built in trainModel
|
|
||||||
models.clear();
|
|
||||||
significanceModels.clear();
|
|
||||||
n_models = 0;
|
|
||||||
// Prepare the validation dataset
|
|
||||||
auto y_ = dataset.index({ -1, "..." });
|
|
||||||
if (convergence) {
|
|
||||||
// Prepare train & validation sets from train data
|
|
||||||
auto fold = folding::StratifiedKFold(5, y_, 271);
|
|
||||||
auto [train, test] = fold.getFold(0);
|
|
||||||
auto train_t = torch::tensor(train);
|
|
||||||
auto test_t = torch::tensor(test);
|
|
||||||
// Get train and validation sets
|
|
||||||
X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), train_t });
|
|
||||||
y_train = dataset.index({ -1, train_t });
|
|
||||||
X_test = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), test_t });
|
|
||||||
y_test = dataset.index({ -1, test_t });
|
|
||||||
dataset = X_train;
|
|
||||||
m = X_train.size(1);
|
|
||||||
auto n_classes = states.at(className).size();
|
|
||||||
// Build dataset with train data
|
|
||||||
buildDataset(y_train);
|
|
||||||
metrics = Metrics(dataset, features, className, n_classes);
|
|
||||||
} else {
|
|
||||||
// Use all data to train
|
|
||||||
X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." });
|
|
||||||
y_train = y_;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
void BoostA2DE::trainModel(const torch::Tensor& weights)
|
void BoostA2DE::trainModel(const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
|
@ -17,7 +17,6 @@ namespace bayesnet {
|
|||||||
virtual ~BoostA2DE() = default;
|
virtual ~BoostA2DE() = default;
|
||||||
std::vector<std::string> graph(const std::string& title = "BoostA2DE") const override;
|
std::vector<std::string> graph(const std::string& title = "BoostA2DE") const override;
|
||||||
protected:
|
protected:
|
||||||
void buildModel(const torch::Tensor& weights) override;
|
|
||||||
void trainModel(const torch::Tensor& weights) override;
|
void trainModel(const torch::Tensor& weights) override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -4,11 +4,11 @@
|
|||||||
// SPDX-License-Identifier: MIT
|
// SPDX-License-Identifier: MIT
|
||||||
// ***************************************************************
|
// ***************************************************************
|
||||||
|
|
||||||
|
#include <random>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <limits.h>
|
#include <limits.h>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include <folding.hpp>
|
|
||||||
#include "BoostAODE.h"
|
#include "BoostAODE.h"
|
||||||
#include "lib/log/loguru.cpp"
|
#include "lib/log/loguru.cpp"
|
||||||
|
|
||||||
@ -17,37 +17,7 @@ namespace bayesnet {
|
|||||||
BoostAODE::BoostAODE(bool predict_voting) : Boost(predict_voting)
|
BoostAODE::BoostAODE(bool predict_voting) : Boost(predict_voting)
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
void BoostAODE::buildModel(const torch::Tensor& weights)
|
|
||||||
{
|
|
||||||
// Models shall be built in trainModel
|
|
||||||
models.clear();
|
|
||||||
significanceModels.clear();
|
|
||||||
n_models = 0;
|
|
||||||
// Prepare the validation dataset
|
|
||||||
auto y_ = dataset.index({ -1, "..." });
|
|
||||||
if (convergence) {
|
|
||||||
// Prepare train & validation sets from train data
|
|
||||||
auto fold = folding::StratifiedKFold(5, y_, 271);
|
|
||||||
auto [train, test] = fold.getFold(0);
|
|
||||||
auto train_t = torch::tensor(train);
|
|
||||||
auto test_t = torch::tensor(test);
|
|
||||||
// Get train and validation sets
|
|
||||||
X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), train_t });
|
|
||||||
y_train = dataset.index({ -1, train_t });
|
|
||||||
X_test = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), test_t });
|
|
||||||
y_test = dataset.index({ -1, test_t });
|
|
||||||
dataset = X_train;
|
|
||||||
m = X_train.size(1);
|
|
||||||
auto n_classes = states.at(className).size();
|
|
||||||
// Build dataset with train data
|
|
||||||
buildDataset(y_train);
|
|
||||||
metrics = Metrics(dataset, features, className, n_classes);
|
|
||||||
} else {
|
|
||||||
// Use all data to train
|
|
||||||
X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." });
|
|
||||||
y_train = y_;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
std::vector<int> BoostAODE::initializeModels()
|
std::vector<int> BoostAODE::initializeModels()
|
||||||
{
|
{
|
||||||
torch::Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
|
torch::Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
|
||||||
|
@ -17,7 +17,6 @@ namespace bayesnet {
|
|||||||
virtual ~BoostAODE() = default;
|
virtual ~BoostAODE() = default;
|
||||||
std::vector<std::string> graph(const std::string& title = "BoostAODE") const override;
|
std::vector<std::string> graph(const std::string& title = "BoostAODE") const override;
|
||||||
protected:
|
protected:
|
||||||
void buildModel(const torch::Tensor& weights) override;
|
|
||||||
void trainModel(const torch::Tensor& weights) override;
|
void trainModel(const torch::Tensor& weights) override;
|
||||||
private:
|
private:
|
||||||
std::vector<int> initializeModels();
|
std::vector<int> initializeModels();
|
||||||
|
Loading…
Reference in New Issue
Block a user