BayesNet/bayesnet/classifiers/TAN.cc

60 lines
2.5 KiB
C++
Raw Normal View History

2024-04-11 16:02:49 +00:00
// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************
#include "TAN.h"
namespace bayesnet {
// Default constructor: initialises the Classifier base with a
// default-constructed Network and lists "parent" in validHyperparameters.
// NOTE(review): validHyperparameters is presumably the set of keys that
// Classifier::setHyperparameters accepts -- confirm in Classifier.
TAN::TAN() : Classifier(Network())
{
// "parent" is the only hyperparameter handled by TAN::setHyperparameters.
validHyperparameters = { "parent" };
}
// Applies the hyperparameters supplied by the caller.
//
// If a "parent" entry is present its value is stored in `parent` and the
// entry is removed; whatever remains is forwarded to
// Classifier::setHyperparameters. The argument itself is never modified:
// all edits happen on a local copy.
void TAN::setHyperparameters(const nlohmann::json& hyperparameters_)
{
    // Work on a copy so the caller's object stays untouched.
    nlohmann::json remaining(hyperparameters_);
    const auto parent_entry = remaining.find("parent");
    if (parent_entry != remaining.end()) {
        parent = *parent_entry;
        remaining.erase(parent_entry);
    }
    Classifier::setHyperparameters(remaining);
}
2023-08-15 13:04:56 +00:00
// Builds the network structure of the TAN classifier.
//
// Steps:
//   0. add every node to the model;
//   1. choose the root feature: `parent` when it was set explicitly,
//      otherwise the feature with the highest mutual information with the
//      class;
//   2. obtain the feature-to-feature edge weights from metrics.conditionalEdge;
//   3. compute the maximum spanning tree over the features, rooted at the
//      chosen feature;
//   4. add the tree edges to the model;
//   5. add an edge from the class to every feature.
//
// weights: forwarded unchanged to metrics.mutualInformation and
//          metrics.conditionalEdge.
//
// Throws std::invalid_argument when `parent` is neither -1 nor a valid
// index into `features`, or when the root has to be selected automatically
// but there are no features.
void TAN::buildModel(const torch::Tensor& weights)
{
    // 0. Add all nodes to the model
    addNodes();
    const int num_features = static_cast<int>(features.size());
    // 1. Select the root node. -1 means "pick it automatically"; any other
    //    value must be a valid feature index (negative values other than -1
    //    are rejected too, they would otherwise reach the spanning tree).
    if (parent < -1 || parent >= num_features) {
        throw std::invalid_argument("The parent node is not in the dataset");
    }
    int root = parent;
    if (root == -1) {
        if (num_features == 0) {
            throw std::invalid_argument("TAN needs at least one feature to select a root node");
        }
        // The class is read from the last row of the dataset, feature i from row i.
        torch::Tensor class_dataset = dataset.index({ -1, "..." });
        float best_mi = 0.0f;
        for (int i = 0; i < num_features; ++i) {
            torch::Tensor feature_dataset = dataset.index({ i, "..." });
            const auto mi_value = static_cast<float>(metrics.mutualInformation(class_dataset, feature_dataset, weights));
            // ">=" resolves ties in favour of the highest feature index.
            if (root == -1 || mi_value >= best_mi) {
                root = i;
                best_mi = mi_value;
            }
        }
    }
    // 2. Compute the edge weights between features
    //    NOTE(review): conditionalEdge presumably returns the conditional
    //    mutual information between feature pairs given the class -- confirm.
    auto weights_matrix = metrics.conditionalEdge(weights);
    // 3. Compute the maximum spanning tree
    auto mst = metrics.maximumSpanningTree(features, weights_matrix, root);
    // 4. Add edges from the maximum spanning tree to the model
    for (const auto& [from, to] : mst) {
        model.addEdge(features[from], features[to]);
    }
    // 5. Add edges from the class to all features
    for (const auto& feature : features) {
        model.addEdge(className, feature);
    }
}
2023-11-08 17:45:35 +00:00
std::vector<std::string> TAN::graph(const std::string& title) const
2023-07-15 23:20:47 +00:00
{
return model.graph(title);
}
}