Complete predict & predict_proba in ensemble

This commit is contained in:
2024-02-24 18:36:09 +01:00
parent 8477698d8d
commit 02e456befb
9 changed files with 104 additions and 101 deletions

View File

@@ -1,10 +1,26 @@
#include "AODE.h"
namespace bayesnet {
AODE::AODE() : Ensemble() {}
AODE::AODE(bool predict_voting) : Ensemble(predict_voting)
{
validHyperparameters = { "predict_voting" };
}
void AODE::setHyperparameters(const nlohmann::json& hyperparameters_)
{
auto hyperparameters = hyperparameters_;
if (hyperparameters.contains("predict_voting")) {
predict_voting = hyperparameters["predict_voting"];
hyperparameters.erase("predict_voting");
}
if (!hyperparameters.empty()) {
throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump());
}
}
void AODE::buildModel(const torch::Tensor& weights)
{
models.clear();
significanceModels.clear();
for (int i = 0; i < features.size(); ++i) {
models.push_back(std::make_unique<SPODE>(i));
}

View File

@@ -4,12 +4,13 @@
#include "SPODE.h"
namespace bayesnet {
class AODE : public Ensemble {
public:
AODE(bool predict_voting = true);
virtual ~AODE() {};
void setHyperparameters(const nlohmann::json& hyperparameters) override;
std::vector<std::string> graph(const std::string& title = "AODE") const override;
protected:
void buildModel(const torch::Tensor& weights) override;
public:
AODE();
virtual ~AODE() {};
std::vector<std::string> graph(const std::string& title = "AODE") const override;
};
}
#endif

View File

@@ -1,7 +1,22 @@
#include "AODELd.h"
namespace bayesnet {
AODELd::AODELd() : Ensemble(), Proposal(dataset, features, className) {}
AODELd::AODELd(bool predict_voting) : Ensemble(predict_voting), Proposal(dataset, features, className)
{
validHyperparameters = { "predict_voting" };
}
void AODELd::setHyperparameters(const nlohmann::json& hyperparameters_)
{
auto hyperparameters = hyperparameters_;
if (hyperparameters.contains("predict_voting")) {
predict_voting = hyperparameters["predict_voting"];
hyperparameters.erase("predict_voting");
}
if (!hyperparameters.empty()) {
throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump());
}
}
AODELd& AODELd::fit(torch::Tensor& X_, torch::Tensor& y_, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_)
{
checkInput(X_, y_);

View File

@@ -6,15 +6,15 @@
namespace bayesnet {
class AODELd : public Ensemble, public Proposal {
public:
AODELd(bool predict_voting = true);
virtual ~AODELd() = default;
AODELd& fit(torch::Tensor& X_, torch::Tensor& y_, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_) override;
void setHyperparameters(const nlohmann::json& hyperparameters) override;
std::vector<std::string> graph(const std::string& name = "AODELd") const override;
protected:
void trainModel(const torch::Tensor& weights) override;
void buildModel(const torch::Tensor& weights) override;
public:
AODELd();
AODELd& fit(torch::Tensor& X_, torch::Tensor& y_, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_) override;
virtual ~AODELd() = default;
std::vector<std::string> graph(const std::string& name = "AODELd") const override;
static inline std::string version() { return "0.0.1"; };
};
}
#endif // !AODELD_H

View File

@@ -10,13 +10,14 @@
namespace bayesnet {
BoostAODE::BoostAODE(bool predict_voting) : Ensemble(predict_voting)
{
validHyperparameters = { "repeatSparent", "maxModels", "ascending", "convergence", "threshold", "select_features", "tolerance" };
validHyperparameters = { "repeatSparent", "maxModels", "ascending", "convergence", "threshold", "select_features", "tolerance", "predict_voting" };
}
void BoostAODE::buildModel(const torch::Tensor& weights)
{
// Models shall be built in trainModel
models.clear();
significanceModels.clear();
n_models = 0;
// Prepare the validation dataset
auto y_ = dataset.index({ -1, "..." });
@@ -72,6 +73,10 @@ namespace bayesnet {
tolerance = hyperparameters["tolerance"];
hyperparameters.erase("tolerance");
}
if (hyperparameters.contains("predict_voting")) {
predict_voting = hyperparameters["predict_voting"];
hyperparameters.erase("predict_voting");
}
if (hyperparameters.contains("select_features")) {
auto selectedAlgorithm = hyperparameters["select_features"];
std::vector<std::string> algos = { "IWSS", "FCBF", "CFS" };
@@ -128,8 +133,11 @@ namespace bayesnet {
if (selectFeatures) {
featuresUsed = initializeModels();
}
if (maxModels == 0)
bool resetMaxModels = false;
if (maxModels == 0) {
maxModels = .1 * n > 10 ? .1 * n : n;
resetMaxModels = true; // Flag to unset maxModels
}
torch::Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
bool exitCondition = false;
// Variables to control the accuracy finish condition
@@ -211,6 +219,9 @@ namespace bayesnet {
status = WARNING;
}
notes.push_back("Number of models: " + std::to_string(n_models));
if (resetMaxModels) {
maxModels = 0;
}
}
std::vector<std::string> BoostAODE::graph(const std::string& title) const
{

View File

@@ -7,7 +7,7 @@
namespace bayesnet {
class BoostAODE : public Ensemble {
public:
BoostAODE(bool predict_voting = false);
BoostAODE(bool predict_voting = true);
virtual ~BoostAODE() = default;
std::vector<std::string> graph(const std::string& title = "BoostAODE") const override;
void setHyperparameters(const nlohmann::json& hyperparameters) override;