Add parent hyperparameter to TAN & SPODE
This commit is contained in:
@@ -7,8 +7,20 @@
|
||||
#include "TAN.h"
|
||||
|
||||
namespace bayesnet {
|
||||
TAN::TAN() : Classifier(Network()) {}
|
||||
TAN::TAN() : Classifier(Network())
|
||||
{
|
||||
validHyperparameters = { "parent" };
|
||||
}
|
||||
|
||||
void TAN::setHyperparameters(const nlohmann::json& hyperparameters_)
|
||||
{
|
||||
auto hyperparameters = hyperparameters_;
|
||||
if (hyperparameters.contains("parent")) {
|
||||
parent = hyperparameters["parent"];
|
||||
hyperparameters.erase("parent");
|
||||
}
|
||||
Classifier::setHyperparameters(hyperparameters);
|
||||
}
|
||||
void TAN::buildModel(const torch::Tensor& weights)
|
||||
{
|
||||
// 0. Add all nodes to the model
|
||||
@@ -23,7 +35,10 @@ namespace bayesnet {
|
||||
mi.push_back({ i, mi_value });
|
||||
}
|
||||
sort(mi.begin(), mi.end(), [](const auto& left, const auto& right) {return left.second < right.second;});
|
||||
auto root = mi[mi.size() - 1].first;
|
||||
auto root = parent == -1 ? mi[mi.size() - 1].first : parent;
|
||||
if (root >= static_cast<int>(features.size())) {
|
||||
throw std::invalid_argument("The parent node is not in the dataset");
|
||||
}
|
||||
// 2. Compute mutual information between each feature and the class
|
||||
auto weights_matrix = metrics.conditionalEdge(weights);
|
||||
// 3. Compute the maximum spanning tree
|
||||
|
Reference in New Issue
Block a user