Refactor to remove root node of network

This commit is contained in:
Ricardo Montañana Gómez 2023-07-06 13:12:41 +02:00
parent 0b33c6c04a
commit 71b88e2c65
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 4 additions and 27 deletions

View File

@ -107,9 +107,6 @@ void showNodesInfo(bayesnet::Network& network, string className)
} }
cout << endl; cout << endl;
} }
cout << "Root: " << network.getRoot()->getName() << endl;
network.setRoot(className);
cout << "Now Root should be class: " << network.getRoot()->getName() << endl;
} }
void showCPDS(bayesnet::Network& network) void showCPDS(bayesnet::Network& network)
{ {

View File

@ -2,10 +2,10 @@
#include <mutex> #include <mutex>
#include "Network.h" #include "Network.h"
namespace bayesnet { namespace bayesnet {
Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0), maxThreads(0.8) {} Network::Network() : laplaceSmoothing(1), features(vector<string>()), className(""), classNumStates(0), maxThreads(0.8) {}
Network::Network(float maxT) : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT) {} Network::Network(float maxT) : laplaceSmoothing(1), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT) {}
Network::Network(float maxT, int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT) {} Network::Network(float maxT, int smoothing) : laplaceSmoothing(smoothing), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT) {}
Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), root(other.root), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), maxThreads(other.getmaxThreads()) Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), maxThreads(other.getmaxThreads())
{ {
for (auto& pair : other.nodes) { for (auto& pair : other.nodes) {
nodes[pair.first] = new Node(*pair.second); nodes[pair.first] = new Node(*pair.second);
@ -29,9 +29,6 @@ namespace bayesnet {
return; return;
} }
nodes[name] = new Node(name, numStates); nodes[name] = new Node(name, numStates);
if (root == nullptr) {
root = nodes[name];
}
} }
vector<string> Network::getFeatures() vector<string> Network::getFeatures()
{ {
@ -45,17 +42,6 @@ namespace bayesnet {
{ {
return className; return className;
} }
void Network::setRoot(string name)
{
if (nodes.find(name) == nodes.end()) {
throw invalid_argument("Node " + name + " does not exist");
}
root = nodes[name];
}
Node* Network::getRoot()
{
return root;
}
bool Network::isCyclic(const string& nodeId, unordered_set<string>& visited, unordered_set<string>& recStack) bool Network::isCyclic(const string& nodeId, unordered_set<string>& visited, unordered_set<string>& recStack)
{ {
if (visited.find(nodeId) == visited.end()) // if node hasn't been visited yet if (visited.find(nodeId) == visited.end()) // if node hasn't been visited yet
@ -227,18 +213,15 @@ namespace bayesnet {
vector<double> result(classNumStates, 0.0); vector<double> result(classNumStates, 0.0);
vector<thread> threads; vector<thread> threads;
mutex mtx; mutex mtx;
for (int i = 0; i < classNumStates; ++i) { for (int i = 0; i < classNumStates; ++i) {
threads.emplace_back([this, &result, &evidence, i, &mtx]() { threads.emplace_back([this, &result, &evidence, i, &mtx]() {
auto completeEvidence = map<string, int>(evidence); auto completeEvidence = map<string, int>(evidence);
completeEvidence[getClassName()] = i; completeEvidence[getClassName()] = i;
double factor = computeFactor(completeEvidence); double factor = computeFactor(completeEvidence);
lock_guard<mutex> lock(mtx); lock_guard<mutex> lock(mtx);
result[i] = factor; result[i] = factor;
}); });
} }
for (auto& thread : threads) { for (auto& thread : threads) {
thread.join(); thread.join();
} }

View File

@ -10,7 +10,6 @@ namespace bayesnet {
private: private:
map<string, Node*> nodes; map<string, Node*> nodes;
map<string, vector<int>> dataset; map<string, vector<int>> dataset;
Node* root;
float maxThreads; float maxThreads;
int classNumStates; int classNumStates;
vector<string> features; vector<string> features;
@ -34,8 +33,6 @@ namespace bayesnet {
int getClassNumStates(); int getClassNumStates();
string getClassName(); string getClassName();
void fit(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&); void fit(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&);
void setRoot(string);
Node* getRoot();
vector<int> predict(const vector<vector<int>>&); vector<int> predict(const vector<vector<int>>&);
vector<pair<int, double>> predict_proba(const vector<vector<int>>&); vector<pair<int, double>> predict_proba(const vector<vector<int>>&);
double score(const vector<vector<int>>&, const vector<int>&); double score(const vector<vector<int>>&, const vector<int>&);