diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a24c987..e2b3ce9 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,2 +1,2 @@ -add_library(BayesNet Network.cc Node.cc ExactInference.cc) +add_library(BayesNet Network.cc Node.cc) target_link_libraries(BayesNet "${TORCH_LIBRARIES}") \ No newline at end of file diff --git a/src/ExactInference.cc b/src/ExactInference.cc deleted file mode 100644 index 82c8d55..0000000 --- a/src/ExactInference.cc +++ /dev/null @@ -1,31 +0,0 @@ -#include "ExactInference.h" - -namespace bayesnet { - ExactInference::ExactInference(Network& net) : network(net) {} - double ExactInference::computeFactor(map& completeEvidence) - { - double result = 1.0; - for (auto node : network.getNodes()) { - result *= node.second->getFactorValue(completeEvidence); - } - return result; - } - vector ExactInference::variableElimination(map& evidence) - { - vector result; - string candidate; - int classNumStates = network.getClassNumStates(); - for (int i = 0; i < classNumStates; ++i) { - result.push_back(1.0); - auto complete_evidence = map(evidence); - complete_evidence[network.getClassName()] = i; - result[i] = computeFactor(complete_evidence); - } - // Normalize result - auto sum = accumulate(result.begin(), result.end(), 0.0); - for (int i = 0; i < result.size(); ++i) { - result[i] /= sum; - } - return result; - } -} \ No newline at end of file diff --git a/src/ExactInference.h b/src/ExactInference.h deleted file mode 100644 index cc838f5..0000000 --- a/src/ExactInference.h +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef EXACTINFERENCE_H -#define EXACTINFERENCE_H -#include "Network.h" -#include "Node.h" -#include -#include -#include -using namespace std; - -namespace bayesnet { - class ExactInference { - private: - Network network; - double computeFactor(map&); - public: - ExactInference(Network&); - vector variableElimination(map&); - }; -} -#endif \ No newline at end of file diff --git a/src/Network.cc b/src/Network.cc index f372c8c..14f3c3c 100644 --- a/src/Network.cc +++ b/src/Network.cc @@ -1,5 +1,4 @@ #include "Network.h" -#include "ExactInference.h" namespace bayesnet { Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector()), className(""), classNumStates(0) {} Network::Network(int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector()), className(""), classNumStates(0) {} @@ -157,12 +156,11 @@ namespace bayesnet { throw invalid_argument("Sample size (" + to_string(sample.size()) + ") does not match the number of features (" + to_string(features.size()) + ")"); } - auto inference = ExactInference(*this); map evidence; for (int i = 0; i < sample.size(); ++i) { evidence[features[i]] = sample[i]; } - vector classProbabilities = inference.variableElimination(evidence); + vector classProbabilities = exactInference(evidence); // Find the class with the maximum posterior probability auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end()); @@ -171,4 +169,29 @@ namespace bayesnet { return make_pair(predictedClass, maxProbability); } + double Network::computeFactor(map& completeEvidence) + { + double result = 1.0; + for (auto node : getNodes()) { + result *= node.second->getFactorValue(completeEvidence); + } + return result; + } + vector Network::exactInference(map& evidence) + { + vector result; + int classNumStates = getClassNumStates(); + for (int i = 0; i < classNumStates; ++i) { + result.push_back(1.0); + auto complete_evidence = map(evidence); + complete_evidence[getClassName()] = i; + result[i] = computeFactor(complete_evidence); + } + // Normalize result + auto sum = accumulate(result.begin(), result.end(), 0.0); + for (int i = 0; i < result.size(); ++i) { + result[i] /= sum; + } + return result; + } } diff --git a/src/Network.h b/src/Network.h index 76050de..2c01e12 100644 --- a/src/Network.h +++ b/src/Network.h @@ -17,6 +17,8 @@ namespace bayesnet { int laplaceSmoothing; bool isCyclic(const std::string&, std::unordered_set&, std::unordered_set&); pair predict_sample(const vector&); + vector exactInference(map&); + double computeFactor(map&); public: Network(); Network(int);