Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #include "SPnDE.h"
8 :
9 : namespace bayesnet {
10 :
11 456 : SPnDE::SPnDE(std::vector<int> parents) : Classifier(Network()), parents(parents) {}
12 :
13 456 : void SPnDE::buildModel(const torch::Tensor& weights)
14 : {
15 : // 0. Add all nodes to the model
16 456 : addNodes();
17 456 : std::vector<int> attributes;
18 4440 : for (int i = 0; i < static_cast<int>(features.size()); ++i) {
19 3984 : if (std::find(parents.begin(), parents.end(), i) == parents.end()) {
20 3072 : attributes.push_back(i);
21 : }
22 : }
23 : // 1. Add edges from the class node to all other nodes
24 : // 2. Add edges from the parents nodes to all other nodes
25 3528 : for (const auto& attribute : attributes) {
26 3072 : model.addEdge(className, features[attribute]);
27 9216 : for (const auto& root : parents) {
28 :
29 6144 : model.addEdge(features[root], features[attribute]);
30 : }
31 : }
32 456 : }
33 24 : std::vector<std::string> SPnDE::graph(const std::string& name) const
34 : {
35 24 : return model.graph(name);
36 : }
37 :
38 : }
|