2023-07-14 23:05:36 +00:00
|
|
|
#include "AODE.h"
|
|
|
|
|
|
|
|
namespace bayesnet {
|
2024-02-24 17:36:09 +00:00
|
|
|
/**
 * @brief Builds an AODE ensemble.
 *
 * @param predict_voting Forwarded unchanged to the Ensemble base constructor.
 *
 * After construction, "predict_voting" is the only key that
 * setHyperparameters() will accept for this classifier.
 */
AODE::AODE(bool predict_voting)
    : Ensemble(predict_voting)
{
    // Declare the single hyperparameter name this classifier understands.
    this->validHyperparameters = { "predict_voting" };
}
|
|
|
|
/**
 * @brief Applies the hyperparameters recognised by AODE.
 *
 * @param hyperparameters_ JSON object that may hold a "predict_voting" entry.
 *        The caller's object is never modified.
 * @throws std::invalid_argument if any key other than "predict_voting" remains.
 *         A non-empty, non-object JSON value also triggers it, because it is
 *         never consumed.
 */
void AODE::setHyperparameters(const nlohmann::json& hyperparameters_)
{
    // Consume recognised keys from a local copy; whatever is left over is invalid.
    nlohmann::json remaining = hyperparameters_;
    const auto found = remaining.find("predict_voting");
    if (found != remaining.end()) {
        predict_voting = *found;
        remaining.erase(found);
    }
    if (remaining.empty()) {
        return;
    }
    throw std::invalid_argument("Invalid hyperparameters" + remaining.dump());
}
|
2023-08-15 13:04:56 +00:00
|
|
|
/**
 * @brief (Re)creates the AODE ensemble: one SPODE per feature.
 *
 * Feature index i becomes the root argument of the i-th SPODE. Any previously
 * built models and significance values are discarded first.
 *
 * @param weights Not read here. The body only constructs the sub-models and
 *        does not fit them.
 *
 * Postconditions: n_models == models.size() == number of features, and
 * every entry of significanceModels is 1.0.
 */
void AODE::buildModel([[maybe_unused]] const torch::Tensor& weights)
{
    // Reset both collections up front so no stale state survives a rebuild.
    models.clear();
    significanceModels.clear();
    // SPODE takes an int root index. Compute the bound once as int so the loop
    // does not compare a signed counter against an unsigned size().
    const int n_features = static_cast<int>(features.size());
    models.reserve(features.size());
    for (int root = 0; root < n_features; ++root) {
        models.push_back(std::make_unique<SPODE>(root));
    }
    n_models = models.size();
    // Every SPODE starts with the same weight (1.0).
    significanceModels = std::vector<double>(n_models, 1.0);
}
|
2023-11-08 17:45:35 +00:00
|
|
|
std::vector<std::string> AODE::graph(const std::string& title) const
|
2023-07-15 23:20:47 +00:00
|
|
|
{
|
|
|
|
return Ensemble::graph(title);
|
|
|
|
}
|
2023-07-14 23:05:36 +00:00
|
|
|
}
|