Refactor ExactInference
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
#include "Network.h"
|
||||
#include "ExactInference.h"
|
||||
namespace bayesnet {
|
||||
Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {}
|
||||
Network::Network(int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {}
|
||||
@@ -157,12 +156,11 @@ namespace bayesnet {
|
||||
throw invalid_argument("Sample size (" + to_string(sample.size()) +
|
||||
") does not match the number of features (" + to_string(features.size()) + ")");
|
||||
}
|
||||
auto inference = ExactInference(*this);
|
||||
map<string, int> evidence;
|
||||
for (int i = 0; i < sample.size(); ++i) {
|
||||
evidence[features[i]] = sample[i];
|
||||
}
|
||||
vector<double> classProbabilities = inference.variableElimination(evidence);
|
||||
vector<double> classProbabilities = exactInference(evidence);
|
||||
|
||||
// Find the class with the maximum posterior probability
|
||||
auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end());
|
||||
@@ -171,4 +169,29 @@ namespace bayesnet {
|
||||
|
||||
return make_pair(predictedClass, maxProbability);
|
||||
}
|
||||
double Network::computeFactor(map<string, int>& completeEvidence)
|
||||
{
|
||||
double result = 1.0;
|
||||
for (auto node : getNodes()) {
|
||||
result *= node.second->getFactorValue(completeEvidence);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
vector<double> Network::exactInference(map<string, int>& evidence)
|
||||
{
|
||||
vector<double> result;
|
||||
int classNumStates = getClassNumStates();
|
||||
for (int i = 0; i < classNumStates; ++i) {
|
||||
result.push_back(1.0);
|
||||
auto complete_evidence = map<string, int>(evidence);
|
||||
complete_evidence[getClassName()] = i;
|
||||
result[i] = computeFactor(complete_evidence);
|
||||
}
|
||||
// Normalize result
|
||||
auto sum = accumulate(result.begin(), result.end(), 0.0);
|
||||
for (int i = 0; i < result.size(); ++i) {
|
||||
result[i] /= sum;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user