Refactor to remove root node of network
This commit is contained in:
parent
0b33c6c04a
commit
71b88e2c65
@ -107,9 +107,6 @@ void showNodesInfo(bayesnet::Network& network, string className)
|
|||||||
}
|
}
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
cout << "Root: " << network.getRoot()->getName() << endl;
|
|
||||||
network.setRoot(className);
|
|
||||||
cout << "Now Root should be class: " << network.getRoot()->getName() << endl;
|
|
||||||
}
|
}
|
||||||
void showCPDS(bayesnet::Network& network)
|
void showCPDS(bayesnet::Network& network)
|
||||||
{
|
{
|
||||||
|
@ -2,10 +2,10 @@
|
|||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include "Network.h"
|
#include "Network.h"
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0), maxThreads(0.8) {}
|
Network::Network() : laplaceSmoothing(1), features(vector<string>()), className(""), classNumStates(0), maxThreads(0.8) {}
|
||||||
Network::Network(float maxT) : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT) {}
|
Network::Network(float maxT) : laplaceSmoothing(1), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT) {}
|
||||||
Network::Network(float maxT, int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT) {}
|
Network::Network(float maxT, int smoothing) : laplaceSmoothing(smoothing), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT) {}
|
||||||
Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), root(other.root), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), maxThreads(other.getmaxThreads())
|
Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), maxThreads(other.getmaxThreads())
|
||||||
{
|
{
|
||||||
for (auto& pair : other.nodes) {
|
for (auto& pair : other.nodes) {
|
||||||
nodes[pair.first] = new Node(*pair.second);
|
nodes[pair.first] = new Node(*pair.second);
|
||||||
@ -29,9 +29,6 @@ namespace bayesnet {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
nodes[name] = new Node(name, numStates);
|
nodes[name] = new Node(name, numStates);
|
||||||
if (root == nullptr) {
|
|
||||||
root = nodes[name];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
vector<string> Network::getFeatures()
|
vector<string> Network::getFeatures()
|
||||||
{
|
{
|
||||||
@ -45,17 +42,6 @@ namespace bayesnet {
|
|||||||
{
|
{
|
||||||
return className;
|
return className;
|
||||||
}
|
}
|
||||||
void Network::setRoot(string name)
|
|
||||||
{
|
|
||||||
if (nodes.find(name) == nodes.end()) {
|
|
||||||
throw invalid_argument("Node " + name + " does not exist");
|
|
||||||
}
|
|
||||||
root = nodes[name];
|
|
||||||
}
|
|
||||||
Node* Network::getRoot()
|
|
||||||
{
|
|
||||||
return root;
|
|
||||||
}
|
|
||||||
bool Network::isCyclic(const string& nodeId, unordered_set<string>& visited, unordered_set<string>& recStack)
|
bool Network::isCyclic(const string& nodeId, unordered_set<string>& visited, unordered_set<string>& recStack)
|
||||||
{
|
{
|
||||||
if (visited.find(nodeId) == visited.end()) // if node hasn't been visited yet
|
if (visited.find(nodeId) == visited.end()) // if node hasn't been visited yet
|
||||||
@ -227,18 +213,15 @@ namespace bayesnet {
|
|||||||
vector<double> result(classNumStates, 0.0);
|
vector<double> result(classNumStates, 0.0);
|
||||||
vector<thread> threads;
|
vector<thread> threads;
|
||||||
mutex mtx;
|
mutex mtx;
|
||||||
|
|
||||||
for (int i = 0; i < classNumStates; ++i) {
|
for (int i = 0; i < classNumStates; ++i) {
|
||||||
threads.emplace_back([this, &result, &evidence, i, &mtx]() {
|
threads.emplace_back([this, &result, &evidence, i, &mtx]() {
|
||||||
auto completeEvidence = map<string, int>(evidence);
|
auto completeEvidence = map<string, int>(evidence);
|
||||||
completeEvidence[getClassName()] = i;
|
completeEvidence[getClassName()] = i;
|
||||||
double factor = computeFactor(completeEvidence);
|
double factor = computeFactor(completeEvidence);
|
||||||
|
|
||||||
lock_guard<mutex> lock(mtx);
|
lock_guard<mutex> lock(mtx);
|
||||||
result[i] = factor;
|
result[i] = factor;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto& thread : threads) {
|
for (auto& thread : threads) {
|
||||||
thread.join();
|
thread.join();
|
||||||
}
|
}
|
||||||
|
@ -10,7 +10,6 @@ namespace bayesnet {
|
|||||||
private:
|
private:
|
||||||
map<string, Node*> nodes;
|
map<string, Node*> nodes;
|
||||||
map<string, vector<int>> dataset;
|
map<string, vector<int>> dataset;
|
||||||
Node* root;
|
|
||||||
float maxThreads;
|
float maxThreads;
|
||||||
int classNumStates;
|
int classNumStates;
|
||||||
vector<string> features;
|
vector<string> features;
|
||||||
@ -34,8 +33,6 @@ namespace bayesnet {
|
|||||||
int getClassNumStates();
|
int getClassNumStates();
|
||||||
string getClassName();
|
string getClassName();
|
||||||
void fit(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&);
|
void fit(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&);
|
||||||
void setRoot(string);
|
|
||||||
Node* getRoot();
|
|
||||||
vector<int> predict(const vector<vector<int>>&);
|
vector<int> predict(const vector<vector<int>>&);
|
||||||
vector<pair<int, double>> predict_proba(const vector<vector<int>>&);
|
vector<pair<int, double>> predict_proba(const vector<vector<int>>&);
|
||||||
double score(const vector<vector<int>>&, const vector<int>&);
|
double score(const vector<vector<int>>&, const vector<int>&);
|
||||||
|
Loading…
Reference in New Issue
Block a user