8.9 KiB
8.9 KiB
<html lang="en">
<head>
</head>
</html>
LCOV - code coverage report | ||||||||||||||||||||||
![]() | ||||||||||||||||||||||
|
||||||||||||||||||||||
![]() |
Line data Source code 1 : // *************************************************************** 2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez 3 : // SPDX-FileType: SOURCE 4 : // SPDX-License-Identifier: MIT 5 : // *************************************************************** 6 : 7 : #include "TAN.h" 8 : 9 : namespace bayesnet { 10 47 : TAN::TAN() : Classifier(Network()) {} 11 : 12 13 : void TAN::buildModel(const torch::Tensor& weights) 13 : { 14 : // 0. Add all nodes to the model 15 13 : addNodes(); 16 : // 1. Compute mutual information between each feature and the class and set the root node 17 : // as the highest mutual information with the class 18 13 : auto mi = std::vector <std::pair<int, float >>(); 19 39 : torch::Tensor class_dataset = dataset.index({ -1, "..." }); 20 89 : for (int i = 0; i < static_cast<int>(features.size()); ++i) { 21 228 : torch::Tensor feature_dataset = dataset.index({ i, "..." }); 22 76 : auto mi_value = metrics.mutualInformation(class_dataset, feature_dataset, weights); 23 76 : mi.push_back({ i, mi_value }); 24 76 : } 25 175 : sort(mi.begin(), mi.end(), [](const auto& left, const auto& right) {return left.second < right.second;}); 26 13 : auto root = mi[mi.size() - 1].first; 27 : // 2. Compute mutual information between each feature and the class 28 13 : auto weights_matrix = metrics.conditionalEdge(weights); 29 : // 3. Compute the maximum spanning tree 30 13 : auto mst = metrics.maximumSpanningTree(features, weights_matrix, root); 31 : // 4. Add edges from the maximum spanning tree to the model 32 76 : for (auto i = 0; i < mst.size(); ++i) { 33 63 : auto [from, to] = mst[i]; 34 63 : model.addEdge(features[from], features[to]); 35 : } 36 : // 5. Add edges from the class to all features 37 89 : for (auto feature : features) { 38 76 : model.addEdge(className, feature); 39 76 : } 40 102 : } 41 2 : std::vector<std::string> TAN::graph(const std::string& title) const 42 : { 43 2 : return model.graph(title); 44 : } 45 : } |
![]() |
Generated by: LCOV version 2.0-1 |
</html>