Refactor library structure

This commit is contained in:
2024-03-08 22:20:54 +01:00
parent 1231f4522a
commit 635ef22520
56 changed files with 64 additions and 68 deletions

View File

@@ -0,0 +1,25 @@
#include "SPODE.h"
namespace bayesnet {
SPODE::SPODE(int root) : Classifier(Network()), root(root) {}
void SPODE::buildModel(const torch::Tensor& weights)
{
// 0. Add all nodes to the model
addNodes();
// 1. Add edges from the class node to all other nodes
// 2. Add edges from the root node to all other nodes
for (int i = 0; i < static_cast<int>(features.size()); ++i) {
model.addEdge(className, features[i]);
if (i != root) {
model.addEdge(features[root], features[i]);
}
}
}
std::vector<std::string> SPODE::graph(const std::string& name) const
{
return model.graph(name);
}
}