Extract buildModel method to parent class in Boost

This commit is contained in:
2024-05-15 20:00:44 +02:00
parent 54496c68f1
commit 8784a24898
6 changed files with 35 additions and 66 deletions

View File

@@ -31,6 +31,7 @@ namespace bayesnet {
void setHyperparameters(const nlohmann::json& hyperparameters_) override;
protected:
std::vector<int> featureSelection(torch::Tensor& weights_);
void buildModel(const torch::Tensor& weights) override;
std::tuple<torch::Tensor&, double, bool> update_weights(torch::Tensor& ytrain, torch::Tensor& ypred, torch::Tensor& weights);
std::tuple<torch::Tensor&, double, bool> update_weights_block(int k, torch::Tensor& ytrain, torch::Tensor& weights);
torch::Tensor X_train, y_train, X_test, y_test;