Refactor ExactInference

This commit is contained in:
Ricardo Montañana Gómez 2023-07-06 11:01:58 +02:00
parent 3de1967b3e
commit 9b70708afb
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
5 changed files with 29 additions and 55 deletions

View File

@ -1,2 +1,2 @@
add_library(BayesNet Network.cc Node.cc ExactInference.cc)
add_library(BayesNet Network.cc Node.cc)
target_link_libraries(BayesNet "${TORCH_LIBRARIES}")

View File

@ -1,31 +0,0 @@
#include "ExactInference.h"
namespace bayesnet {
ExactInference::ExactInference(Network& net) : network(net) {}
double ExactInference::computeFactor(map<string, int>& completeEvidence)
{
double result = 1.0;
for (auto node : network.getNodes()) {
result *= node.second->getFactorValue(completeEvidence);
}
return result;
}
vector<double> ExactInference::variableElimination(map<string, int>& evidence)
{
vector<double> result;
string candidate;
int classNumStates = network.getClassNumStates();
for (int i = 0; i < classNumStates; ++i) {
result.push_back(1.0);
auto complete_evidence = map<string, int>(evidence);
complete_evidence[network.getClassName()] = i;
result[i] = computeFactor(complete_evidence);
}
// Normalize result
auto sum = accumulate(result.begin(), result.end(), 0.0);
for (int i = 0; i < result.size(); ++i) {
result[i] /= sum;
}
return result;
}
}

View File

@ -1,20 +0,0 @@
#ifndef EXACTINFERENCE_H
#define EXACTINFERENCE_H
#include "Network.h"
#include "Node.h"
#include <map>
#include <vector>
#include <string>
using namespace std;
namespace bayesnet {
class ExactInference {
private:
Network network;
double computeFactor(map<string, int>&);
public:
ExactInference(Network&);
vector<double> variableElimination(map<string, int>&);
};
}
#endif

View File

@ -1,5 +1,4 @@
#include "Network.h"
#include "ExactInference.h"
namespace bayesnet {
Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {}
Network::Network(int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {}
@ -157,12 +156,11 @@ namespace bayesnet {
throw invalid_argument("Sample size (" + to_string(sample.size()) +
") does not match the number of features (" + to_string(features.size()) + ")");
}
auto inference = ExactInference(*this);
map<string, int> evidence;
for (int i = 0; i < sample.size(); ++i) {
evidence[features[i]] = sample[i];
}
vector<double> classProbabilities = inference.variableElimination(evidence);
vector<double> classProbabilities = exactInference(evidence);
// Find the class with the maximum posterior probability
auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end());
@ -171,4 +169,29 @@ namespace bayesnet {
return make_pair(predictedClass, maxProbability);
}
double Network::computeFactor(map<string, int>& completeEvidence)
{
double result = 1.0;
for (auto node : getNodes()) {
result *= node.second->getFactorValue(completeEvidence);
}
return result;
}
vector<double> Network::exactInference(map<string, int>& evidence)
{
vector<double> result;
int classNumStates = getClassNumStates();
for (int i = 0; i < classNumStates; ++i) {
result.push_back(1.0);
auto complete_evidence = map<string, int>(evidence);
complete_evidence[getClassName()] = i;
result[i] = computeFactor(complete_evidence);
}
// Normalize result
auto sum = accumulate(result.begin(), result.end(), 0.0);
for (int i = 0; i < result.size(); ++i) {
result[i] /= sum;
}
return result;
}
}

View File

@ -17,6 +17,8 @@ namespace bayesnet {
int laplaceSmoothing;
bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
pair<int, double> predict_sample(const vector<int>&);
vector<double> exactInference(map<string, int>&);
double computeFactor(map<string, int>&);
public:
Network();
Network(int);