From 71b88e2c65d5cc7d91cd8a8bc91030e773c5fb9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Thu, 6 Jul 2023 13:12:41 +0200 Subject: [PATCH] Refactor to remove root node of network --- sample/main.cc | 3 --- src/Network.cc | 25 ++++--------------------- src/Network.h | 3 --- 3 files changed, 4 insertions(+), 27 deletions(-) diff --git a/sample/main.cc b/sample/main.cc index 00aeac5..5d11d9a 100644 --- a/sample/main.cc +++ b/sample/main.cc @@ -107,9 +107,6 @@ void showNodesInfo(bayesnet::Network& network, string className) } cout << endl; } - cout << "Root: " << network.getRoot()->getName() << endl; - network.setRoot(className); - cout << "Now Root should be class: " << network.getRoot()->getName() << endl; } void showCPDS(bayesnet::Network& network) { diff --git a/src/Network.cc b/src/Network.cc index b29465d..130d317 100644 --- a/src/Network.cc +++ b/src/Network.cc @@ -2,10 +2,10 @@ #include #include "Network.h" namespace bayesnet { - Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector()), className(""), classNumStates(0), maxThreads(0.8) {} - Network::Network(float maxT) : laplaceSmoothing(1), root(nullptr), features(vector()), className(""), classNumStates(0), maxThreads(maxT) {} - Network::Network(float maxT, int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector()), className(""), classNumStates(0), maxThreads(maxT) {} - Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), root(other.root), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), maxThreads(other.getmaxThreads()) + Network::Network() : laplaceSmoothing(1), features(vector()), className(""), classNumStates(0), maxThreads(0.8) {} + Network::Network(float maxT) : laplaceSmoothing(1), features(vector()), className(""), classNumStates(0), maxThreads(maxT) {} + Network::Network(float maxT, int smoothing) : laplaceSmoothing(smoothing), features(vector()), className(""), classNumStates(0), maxThreads(maxT) {} + Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), maxThreads(other.getmaxThreads()) { for (auto& pair : other.nodes) { nodes[pair.first] = new Node(*pair.second); @@ -29,9 +29,6 @@ namespace bayesnet { return; } nodes[name] = new Node(name, numStates); - if (root == nullptr) { - root = nodes[name]; - } } vector Network::getFeatures() { @@ -45,17 +42,6 @@ namespace bayesnet { { return className; } - void Network::setRoot(string name) - { - if (nodes.find(name) == nodes.end()) { - throw invalid_argument("Node " + name + " does not exist"); - } - root = nodes[name]; - } - Node* Network::getRoot() - { - return root; - } bool Network::isCyclic(const string& nodeId, unordered_set& visited, unordered_set& recStack) { if (visited.find(nodeId) == visited.end()) // if node hasn't been visited yet @@ -227,18 +213,15 @@ namespace bayesnet { vector result(classNumStates, 0.0); vector threads; mutex mtx; - for (int i = 0; i < classNumStates; ++i) { threads.emplace_back([this, &result, &evidence, i, &mtx]() { auto completeEvidence = map(evidence); completeEvidence[getClassName()] = i; double factor = computeFactor(completeEvidence); - lock_guard lock(mtx); result[i] = factor; }); } - for (auto& thread : threads) { thread.join(); } diff --git a/src/Network.h b/src/Network.h index 9ba531c..2eb2b82 100644 --- a/src/Network.h +++ b/src/Network.h @@ -10,7 +10,6 @@ namespace bayesnet { private: map nodes; map> dataset; - Node* root; float maxThreads; int classNumStates; vector features; @@ -34,8 +33,6 @@ namespace bayesnet { int getClassNumStates(); string getClassName(); void fit(const vector>&, const vector&, const vector&, const string&); - void setRoot(string); - Node* getRoot(); vector predict(const vector>&); vector> predict_proba(const vector>&); double score(const vector>&, const vector&);