// File: BayesNet/bayesnet/classifiers/SPODE.cc
// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************
#include "SPODE.h"

#include <stdexcept>
#include <string>
namespace bayesnet {
2023-07-22 21:07:56 +00:00
SPODE::SPODE(int root) : Classifier(Network()), root(root) {}
2023-07-14 10:59:47 +00:00
2023-08-15 13:04:56 +00:00
void SPODE::buildModel(const torch::Tensor& weights)
2023-07-14 10:59:47 +00:00
{
// 0. Add all nodes to the model
addNodes();
// 1. Add edges from the class node to all other nodes
// 2. Add edges from the root node to all other nodes
for (int i = 0; i < static_cast<int>(features.size()); ++i) {
model.addEdge(className, features[i]);
if (i != root) {
model.addEdge(features[root], features[i]);
}
}
}
2023-11-08 17:45:35 +00:00
std::vector<std::string> SPODE::graph(const std::string& name) const
2023-07-15 23:20:47 +00:00
{
return model.graph(name);
}
2023-07-14 10:59:47 +00:00
}