diff --git a/README.md b/README.md index 4fd0db2..98a7732 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ The process is repeated until there are no more variables to eliminate. ## Code for combination +```cpp // Combinations of length 2 vector combinations(vector source) { @@ -37,3 +38,75 @@ vector combinations(vector source) } return result; } +``` + +## Code for Variable Elimination + +```cpp +// Variable Elimination +vector variableElimination(vector source, map> graph) +{ + vector variables = source; + vector factors = source; + while (variables.size() > 0) { + string variable = minFill(variables, graph); + vector neighbors = graph[variable]; + vector combinations = combinations(neighbors); + vector factorsToMultiply; + for (int i = 0; i < factors.size(); ++i) { + string factor = factors[i]; + for (int j = 0; j < combinations.size(); ++j) { + string combination = combinations[j]; + if (factor.find(combination) != string::npos) { + factorsToMultiply.push_back(factor); + break; + } + } + } + string newFactor = multiplyFactors(factorsToMultiply); + factors.push_back(newFactor); + variables.erase(remove(variables.begin(), variables.end(), variable), variables.end()); + } + return factors; +} +``` + +## Network copy constructor + +```cpp +// Network copy constructor +Network::Network(const Network& network) +{ + this->variables = network.variables; + this->factors = network.factors; + this->graph = network.graph; +} +``` + +## Code for MinFill + +```cpp +// MinFill +string minFill(vector source, map> graph) +{ + string result; + int min = INT_MAX; + for (int i = 0; i < source.size(); ++i) { + string temp = source[i]; + int count = 0; + vector neighbors = graph[temp]; + vector combinations = combinations(neighbors); + for (int j = 0; j < combinations.size(); ++j) { + string combination = combinations[j]; + if (graph[combination].size() == 0) { + count++; + } + } + if (count < min) { + min = count; + result = temp; + } + } + return result; +} +``` diff --git a/src/Network.cc b/src/Network.cc index 9f926b3..c94c291 100644 --- a/src/Network.cc +++ b/src/Network.cc @@ -2,6 +2,12 @@ namespace bayesnet { Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector()), className("") {} Network::Network(int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector()), className("") {} + Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), root(other.root), features(other.features), className(other.className) + { + for (auto& pair : other.nodes) { + nodes[pair.first] = new Node(*pair.second); + } + } Network::~Network() { for (auto& pair : nodes) { diff --git a/src/Network.h b/src/Network.h index a155fed..3f907dd 100644 --- a/src/Network.h +++ b/src/Network.h @@ -19,6 +19,7 @@ namespace bayesnet { public: Network(); Network(int); + Network(Network&); ~Network(); void addNode(string, int); void addEdge(const string, const string);