From 31c22898deef1c5b30db8731058fef8fcd933634 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Thu, 29 Jun 2023 23:53:33 +0200 Subject: [PATCH] Add cycle detect adding edges --- Network.cc | 71 ++++++++++++++++++++++++++++++++++++++++++++---------- Network.h | 4 +++ Node.cc | 9 ++++++- Node.h | 10 +++++--- main.cc | 11 +++++++++ 5 files changed, 87 insertions(+), 18 deletions(-) diff --git a/Network.cc b/Network.cc index a5c41c2..a81c150 100644 --- a/Network.cc +++ b/Network.cc @@ -6,30 +6,75 @@ namespace bayesnet { delete pair.second; } } - void Network::addNode(std::string name, int numStates) + void Network::addNode(string name, int numStates) { nodes[name] = new Node(name, numStates); + if (root == nullptr) { + root = nodes[name]; + } } - void Network::addEdge(const std::string parent, const std::string child) + void Network::setRoot(string name) + { + if (nodes.find(name) == nodes.end()) { + throw invalid_argument("Node " + name + " does not exist"); + } + root = nodes[name]; + } + Node* Network::getRoot() + { + return root; + } + bool Network::isCyclic(const string& nodeId, unordered_set& visited, unordered_set& recStack) + { + if (visited.find(nodeId) == visited.end()) // if node hasn't been visited yet + { + visited.insert(nodeId); + recStack.insert(nodeId); + + for (Node* child : nodes[nodeId]->getChildren()) { + if (visited.find(child->getName()) == visited.end() && isCyclic(child->getName(), visited, recStack)) + return true; + else if (recStack.find(child->getName()) != recStack.end()) + return true; + } + } + recStack.erase(nodeId); // remove node from recursion stack before function ends + return false; + } + void Network::addEdge(const string parent, const string child) { if (nodes.find(parent) == nodes.end()) { - throw std::invalid_argument("Parent node " + parent + " does not exist"); + throw invalid_argument("Parent node " + parent + " does not exist"); } if (nodes.find(child) == nodes.end()) { - throw std::invalid_argument("Child node " + child + " does not exist"); + throw invalid_argument("Child node " + child + " does not exist"); } + // Temporarily add edge to check for cycles nodes[parent]->addChild(nodes[child]); nodes[child]->addParent(nodes[parent]); + // temporarily add edge + + unordered_set visited; + unordered_set recStack; + + if (isCyclic(nodes[child]->getName(), visited, recStack)) // if adding this edge forms a cycle + { + // remove edge + nodes[parent]->removeChild(nodes[child]); + nodes[child]->removeParent(nodes[parent]); + throw invalid_argument("Adding this edge forms a cycle in the graph."); + } + } - std::map& Network::getNodes() + map& Network::getNodes() { return nodes; } - void Network::fit(const std::vector>& dataset, const int smoothing) + void Network::fit(const vector>& dataset, const int smoothing) { - auto jointCounts = [](const std::vector>& data, const std::vector& indices, int numStates) { + auto jointCounts = [](const vector>& data, const vector& indices, int numStates) { int size = indices.size(); - std::vector sizes(size, numStates); + vector sizes(size, numStates); torch::Tensor counts = torch::zeros(sizes, torch::kLong); for (const auto& row : data) { @@ -41,16 +86,16 @@ namespace bayesnet { } return counts; - }; + }; auto marginalCounts = [](const torch::Tensor& jointCounts) { return jointCounts.sum(-1); - }; + }; for (auto& pair : nodes) { Node* node = pair.second; - std::vector indices; + vector indices; for (const auto& parent : node->getParents()) { indices.push_back(nodes[parent->getName()]->getId()); } @@ -67,12 +112,12 @@ namespace bayesnet { } } - torch::Tensor& Network::getCPD(const std::string& key) + torch::Tensor& Network::getCPD(const string& key) { return cpds[key]; } - void Network::setCPD(const std::string& key, const torch::Tensor& cpt) + void Network::setCPD(const string& key, const torch::Tensor& cpt) { cpds[key] = cpt; } diff --git a/Network.h b/Network.h index 94bafdc..8b287eb 100644 --- a/Network.h +++ b/Network.h @@ -8,6 +8,8 @@ namespace bayesnet { private: map nodes; map cpds; // Map from CPD key to CPD tensor + Node* root = nullptr; + bool isCyclic(const std::string&, std::unordered_set&, std::unordered_set&); public: ~Network(); void addNode(string, int); @@ -16,6 +18,8 @@ namespace bayesnet { void fit(const vector>&, const int); torch::Tensor& getCPD(const string&); void setCPD(const string&, const torch::Tensor&); + void setRoot(string); + Node* getRoot(); }; } #endif \ No newline at end of file diff --git a/Node.cc b/Node.cc index 85b1f1d..b7f72f1 100644 --- a/Node.cc +++ b/Node.cc @@ -17,7 +17,14 @@ namespace bayesnet { { parents.push_back(parent); } - + void Node::removeParent(Node* parent) + { + parents.erase(std::remove(parents.begin(), parents.end(), parent), parents.end()); + } + void Node::removeChild(Node* child) + { + children.erase(std::remove(children.begin(), children.end(), child), children.end()); + } void Node::addChild(Node* child) { children.push_back(child); diff --git a/Node.h b/Node.h index a3ab4f3..52a6d4f 100644 --- a/Node.h +++ b/Node.h @@ -15,14 +15,16 @@ namespace bayesnet { int numStates; torch::Tensor cpt; public: - Node(const std::string& name, int numStates); - void addParent(Node* parent); - void addChild(Node* child); + Node(const std::string&, int); + void addParent(Node*); + void addChild(Node*); + void removeParent(Node*); + void removeChild(Node*); string getName() const; vector& getParents(); vector& getChildren(); torch::Tensor& getCPT(); - void setCPT(const torch::Tensor& cpt); + void setCPT(const torch::Tensor&); int getNumStates() const; int getId() const { return id; } string getCPDKey(const Node*) const; diff --git a/main.cc b/main.cc index 8a1d6d1..49bbe75 100644 --- a/main.cc +++ b/main.cc @@ -26,6 +26,13 @@ int main() } cout << "Hello, Bayesian Networks!" << endl; torch::Tensor tensor = torch::eye(3); + cout << "Now I'll add a cycle" << endl; + try { + network.addEdge("petallength", className); + } + catch (invalid_argument& e) { + cout << e.what() << endl; + } cout << tensor << std::endl; cout << "Nodes:" << endl; for (auto [name, item] : network.getNodes()) { @@ -39,5 +46,9 @@ int main() cout << " " << child->getName() << endl; } } + cout << "Root: " << network.getRoot()->getName() << endl; + network.setRoot(className); + cout << "Now Root should be class: " << network.getRoot()->getName() << endl; + return 0; } \ No newline at end of file