Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #include "TAN.h"
8 :
9 : namespace bayesnet {
10 282 : TAN::TAN() : Classifier(Network()) {}
11 :
12 78 : void TAN::buildModel(const torch::Tensor& weights)
13 : {
14 : // 0. Add all nodes to the model
15 78 : addNodes();
16 : // 1. Compute mutual information between each feature and the class and set the root node
17 : // as the highest mutual information with the class
18 78 : auto mi = std::vector <std::pair<int, float >>();
19 234 : torch::Tensor class_dataset = dataset.index({ -1, "..." });
20 534 : for (int i = 0; i < static_cast<int>(features.size()); ++i) {
21 1368 : torch::Tensor feature_dataset = dataset.index({ i, "..." });
22 456 : auto mi_value = metrics.mutualInformation(class_dataset, feature_dataset, weights);
23 456 : mi.push_back({ i, mi_value });
24 456 : }
25 1050 : sort(mi.begin(), mi.end(), [](const auto& left, const auto& right) {return left.second < right.second;});
26 78 : auto root = mi[mi.size() - 1].first;
27 : // 2. Compute mutual information between each feature and the class
28 78 : auto weights_matrix = metrics.conditionalEdge(weights);
29 : // 3. Compute the maximum spanning tree
30 78 : auto mst = metrics.maximumSpanningTree(features, weights_matrix, root);
31 : // 4. Add edges from the maximum spanning tree to the model
32 456 : for (auto i = 0; i < mst.size(); ++i) {
33 378 : auto [from, to] = mst[i];
34 378 : model.addEdge(features[from], features[to]);
35 : }
36 : // 5. Add edges from the class to all features
37 534 : for (auto feature : features) {
38 456 : model.addEdge(className, feature);
39 456 : }
40 612 : }
41 12 : std::vector<std::string> TAN::graph(const std::string& title) const
42 : {
43 12 : return model.graph(title);
44 : }
45 : }
|