// File: BayesNet/bayesnet/classifiers/SPODE.cc
// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************
#include "SPODE.h"
namespace bayesnet {
SPODE::SPODE(int root) : Classifier(Network()), root(root)
{
validHyperparameters = { "parent" };
}
2023-07-14 10:59:47 +00:00
void SPODE::setHyperparameters(const nlohmann::json& hyperparameters_)
{
auto hyperparameters = hyperparameters_;
if (hyperparameters.contains("parent")) {
root = hyperparameters["parent"];
hyperparameters.erase("parent");
}
Classifier::setHyperparameters(hyperparameters);
}
2023-08-15 13:04:56 +00:00
void SPODE::buildModel(const torch::Tensor& weights)
2023-07-14 10:59:47 +00:00
{
// 0. Add all nodes to the model
addNodes();
// 1. Add edges from the class node to all other nodes
// 2. Add edges from the root node to all other nodes
if (root >= static_cast<int>(features.size())) {
throw std::invalid_argument("The parent node is not in the dataset");
}
2023-07-14 10:59:47 +00:00
for (int i = 0; i < static_cast<int>(features.size()); ++i) {
model.addEdge(className, features[i]);
if (i != root) {
model.addEdge(features[root], features[i]);
}
}
}
2023-11-08 17:45:35 +00:00
std::vector<std::string> SPODE::graph(const std::string& name) const
2023-07-15 23:20:47 +00:00
{
return model.graph(name);
}
2023-07-14 10:59:47 +00:00
}