Begin testing ensemble predict_proba

This commit is contained in:
2024-02-22 18:44:40 +01:00
parent 443e5cc882
commit 3116eaa763
4 changed files with 67 additions and 23 deletions

View File

@@ -8,7 +8,7 @@
#include "folding.hpp"
namespace bayesnet {
BoostAODE::BoostAODE() : Ensemble(false)
BoostAODE::BoostAODE(bool predict_voting) : Ensemble(predict_voting)
{
validHyperparameters = { "repeatSparent", "maxModels", "ascending", "convergence", "threshold", "select_features", "tolerance" };

View File

@@ -7,7 +7,7 @@
namespace bayesnet {
class BoostAODE : public Ensemble {
public:
BoostAODE();
BoostAODE(bool predict_voting = false);
virtual ~BoostAODE() = default;
std::vector<std::string> graph(const std::string& title = "BoostAODE") const override;
void setHyperparameters(const nlohmann::json& hyperparameters) override;

View File

@@ -36,7 +36,6 @@ namespace bayesnet {
std::vector<double> significanceModels;
void trainModel(const torch::Tensor& weights) override;
std::vector<int> voting(torch::Tensor& y_pred);
private:
bool predict_voting;
};
}