diff --git a/bayesnet/network/Network.cc b/bayesnet/network/Network.cc index f26e19d..b6482db 100644 --- a/bayesnet/network/Network.cc +++ b/bayesnet/network/Network.cc @@ -12,6 +12,7 @@ #include "bayesnet/utils/bayesnetUtils.h" #include "bayesnet/utils/CountingSemaphore.h" #include +#include namespace bayesnet { Network::Network() : fitted{ false }, classNumStates{ 0 } { @@ -40,6 +41,9 @@ namespace bayesnet { } void Network::addNode(const std::string& name) { + if (fitted) { + throw std::invalid_argument("Cannot add node to a fitted network. Initialize first."); + } if (name == "") { throw std::invalid_argument("Node name cannot be empty"); } @@ -89,6 +93,9 @@ namespace bayesnet { } void Network::addEdge(const std::string& parent, const std::string& child) { + if (fitted) { + throw std::invalid_argument("Cannot add edge to a fitted network. Initialize first."); + } if (nodes.find(parent) == nodes.end()) { throw std::invalid_argument("Parent node " + parent + " does not exist"); } @@ -227,6 +234,16 @@ namespace bayesnet { for (auto& thread : threads) { thread.join(); } + // std::fstream file; + // file.open("cpt.txt", std::fstream::out | std::fstream::app); + // file << std::string(80, '*') << std::endl; + // for (const auto& item : graph("Test")) { + // file << item << std::endl; + // } + // file << std::string(80, '-') << std::endl; + // file << dump_cpt() << std::endl; + // file << std::string(80, '=') << std::endl; + // file.close(); fitted = true; } torch::Tensor Network::predict_tensor(const torch::Tensor& samples, const bool proba) diff --git a/bayesnet/network/Node.cc b/bayesnet/network/Node.cc index 28eba72..44fc900 100644 --- a/bayesnet/network/Node.cc +++ b/bayesnet/network/Node.cc @@ -104,16 +104,18 @@ namespace bayesnet { throw std::logic_error("Feature " + name + " not found in dataset"); } int name_index = pos - features.begin(); + c10::List> coordinates; for (int n_sample = 0; n_sample < dataset.size(1); ++n_sample) { - c10::List> coordinates; - coordinates.push_back(dataset.index({ name_index, n_sample })); + coordinates.clear(); + auto sample = dataset.index({ "...", n_sample }); + coordinates.push_back(sample[name_index]); for (auto parent : parents) { pos = find(features.begin(), features.end(), parent->getName()); if (pos == features.end()) { throw std::logic_error("Feature parent " + parent->getName() + " not found in dataset"); } int parent_index = pos - features.begin(); - coordinates.push_back(dataset.index({ parent_index, n_sample })); + coordinates.push_back(sample[parent_index]); } // Increment the count of the corresponding coordinate cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + weights.index({ n_sample }).item()); @@ -134,8 +136,8 @@ namespace bayesnet { { auto output = std::vector(); auto suffix = name == className ? ", fontcolor=red, fillcolor=lightblue, style=filled " : ""; - output.push_back(name + " [shape=circle" + suffix + "] \n"); - transform(children.begin(), children.end(), back_inserter(output), [this](const auto& child) { return name + " -> " + child->getName(); }); + output.push_back("\"" + name + "\" [shape=circle" + suffix + "] \n"); + transform(children.begin(), children.end(), back_inserter(output), [this](const auto& child) { return "\"" + name + "\" -> \"" + child->getName() + "\""; }); return output; } } \ No newline at end of file diff --git a/lib/json b/lib/json index 8c391e0..960b763 160000 --- a/lib/json +++ b/lib/json @@ -1 +1 @@ -Subproject commit 8c391e04fe4195d8be862c97f38cfe10e2a3472e +Subproject commit 960b763ecd144f156d05ec61f577b04107290137 diff --git a/lib/mdlp b/lib/mdlp index e36d9af..2db60e0 160000 --- a/lib/mdlp +++ b/lib/mdlp @@ -1 +1 @@ -Subproject commit e36d9af8f939a57266e30ca96e1cf84fc7d107b0 +Subproject commit 2db60e007d70da876379373c53b6421f281daeac