Don't allow add node nor add edge on fitted networks
This commit is contained in:
parent
26eb58b104
commit
35ca862eca
@ -12,6 +12,7 @@
|
|||||||
#include "bayesnet/utils/bayesnetUtils.h"
|
#include "bayesnet/utils/bayesnetUtils.h"
|
||||||
#include "bayesnet/utils/CountingSemaphore.h"
|
#include "bayesnet/utils/CountingSemaphore.h"
|
||||||
#include <pthread.h>
|
#include <pthread.h>
|
||||||
|
#include <fstream>
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
Network::Network() : fitted{ false }, classNumStates{ 0 }
|
Network::Network() : fitted{ false }, classNumStates{ 0 }
|
||||||
{
|
{
|
||||||
@ -40,6 +41,9 @@ namespace bayesnet {
|
|||||||
}
|
}
|
||||||
void Network::addNode(const std::string& name)
|
void Network::addNode(const std::string& name)
|
||||||
{
|
{
|
||||||
|
if (fitted) {
|
||||||
|
throw std::invalid_argument("Cannot add node to a fitted network. Initialize first.");
|
||||||
|
}
|
||||||
if (name == "") {
|
if (name == "") {
|
||||||
throw std::invalid_argument("Node name cannot be empty");
|
throw std::invalid_argument("Node name cannot be empty");
|
||||||
}
|
}
|
||||||
@ -89,6 +93,9 @@ namespace bayesnet {
|
|||||||
}
|
}
|
||||||
void Network::addEdge(const std::string& parent, const std::string& child)
|
void Network::addEdge(const std::string& parent, const std::string& child)
|
||||||
{
|
{
|
||||||
|
if (fitted) {
|
||||||
|
throw std::invalid_argument("Cannot add edge to a fitted network. Initialize first.");
|
||||||
|
}
|
||||||
if (nodes.find(parent) == nodes.end()) {
|
if (nodes.find(parent) == nodes.end()) {
|
||||||
throw std::invalid_argument("Parent node " + parent + " does not exist");
|
throw std::invalid_argument("Parent node " + parent + " does not exist");
|
||||||
}
|
}
|
||||||
@ -227,6 +234,16 @@ namespace bayesnet {
|
|||||||
for (auto& thread : threads) {
|
for (auto& thread : threads) {
|
||||||
thread.join();
|
thread.join();
|
||||||
}
|
}
|
||||||
|
// std::fstream file;
|
||||||
|
// file.open("cpt.txt", std::fstream::out | std::fstream::app);
|
||||||
|
// file << std::string(80, '*') << std::endl;
|
||||||
|
// for (const auto& item : graph("Test")) {
|
||||||
|
// file << item << std::endl;
|
||||||
|
// }
|
||||||
|
// file << std::string(80, '-') << std::endl;
|
||||||
|
// file << dump_cpt() << std::endl;
|
||||||
|
// file << std::string(80, '=') << std::endl;
|
||||||
|
// file.close();
|
||||||
fitted = true;
|
fitted = true;
|
||||||
}
|
}
|
||||||
torch::Tensor Network::predict_tensor(const torch::Tensor& samples, const bool proba)
|
torch::Tensor Network::predict_tensor(const torch::Tensor& samples, const bool proba)
|
||||||
|
@ -104,16 +104,18 @@ namespace bayesnet {
|
|||||||
throw std::logic_error("Feature " + name + " not found in dataset");
|
throw std::logic_error("Feature " + name + " not found in dataset");
|
||||||
}
|
}
|
||||||
int name_index = pos - features.begin();
|
int name_index = pos - features.begin();
|
||||||
|
c10::List<c10::optional<at::Tensor>> coordinates;
|
||||||
for (int n_sample = 0; n_sample < dataset.size(1); ++n_sample) {
|
for (int n_sample = 0; n_sample < dataset.size(1); ++n_sample) {
|
||||||
c10::List<c10::optional<at::Tensor>> coordinates;
|
coordinates.clear();
|
||||||
coordinates.push_back(dataset.index({ name_index, n_sample }));
|
auto sample = dataset.index({ "...", n_sample });
|
||||||
|
coordinates.push_back(sample[name_index]);
|
||||||
for (auto parent : parents) {
|
for (auto parent : parents) {
|
||||||
pos = find(features.begin(), features.end(), parent->getName());
|
pos = find(features.begin(), features.end(), parent->getName());
|
||||||
if (pos == features.end()) {
|
if (pos == features.end()) {
|
||||||
throw std::logic_error("Feature parent " + parent->getName() + " not found in dataset");
|
throw std::logic_error("Feature parent " + parent->getName() + " not found in dataset");
|
||||||
}
|
}
|
||||||
int parent_index = pos - features.begin();
|
int parent_index = pos - features.begin();
|
||||||
coordinates.push_back(dataset.index({ parent_index, n_sample }));
|
coordinates.push_back(sample[parent_index]);
|
||||||
}
|
}
|
||||||
// Increment the count of the corresponding coordinate
|
// Increment the count of the corresponding coordinate
|
||||||
cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + weights.index({ n_sample }).item<double>());
|
cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + weights.index({ n_sample }).item<double>());
|
||||||
@ -134,8 +136,8 @@ namespace bayesnet {
|
|||||||
{
|
{
|
||||||
auto output = std::vector<std::string>();
|
auto output = std::vector<std::string>();
|
||||||
auto suffix = name == className ? ", fontcolor=red, fillcolor=lightblue, style=filled " : "";
|
auto suffix = name == className ? ", fontcolor=red, fillcolor=lightblue, style=filled " : "";
|
||||||
output.push_back(name + " [shape=circle" + suffix + "] \n");
|
output.push_back("\"" + name + "\" [shape=circle" + suffix + "] \n");
|
||||||
transform(children.begin(), children.end(), back_inserter(output), [this](const auto& child) { return name + " -> " + child->getName(); });
|
transform(children.begin(), children.end(), back_inserter(output), [this](const auto& child) { return "\"" + name + "\" -> \"" + child->getName() + "\""; });
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
}
|
}
|
2
lib/json
2
lib/json
@ -1 +1 @@
|
|||||||
Subproject commit 8c391e04fe4195d8be862c97f38cfe10e2a3472e
|
Subproject commit 960b763ecd144f156d05ec61f577b04107290137
|
2
lib/mdlp
2
lib/mdlp
@ -1 +1 @@
|
|||||||
Subproject commit e36d9af8f939a57266e30ca96e1cf84fc7d107b0
|
Subproject commit 2db60e007d70da876379373c53b6421f281daeac
|
Loading…
Reference in New Issue
Block a user