Refactor ExactInference
This commit is contained in:
parent
3de1967b3e
commit
9b70708afb
@ -1,2 +1,2 @@
|
||||
add_library(BayesNet Network.cc Node.cc ExactInference.cc)
|
||||
add_library(BayesNet Network.cc Node.cc)
|
||||
target_link_libraries(BayesNet "${TORCH_LIBRARIES}")
|
@ -1,31 +0,0 @@
|
||||
#include "ExactInference.h"
|
||||
|
||||
namespace bayesnet {
|
||||
ExactInference::ExactInference(Network& net) : network(net) {}
|
||||
double ExactInference::computeFactor(map<string, int>& completeEvidence)
|
||||
{
|
||||
double result = 1.0;
|
||||
for (auto node : network.getNodes()) {
|
||||
result *= node.second->getFactorValue(completeEvidence);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
vector<double> ExactInference::variableElimination(map<string, int>& evidence)
|
||||
{
|
||||
vector<double> result;
|
||||
string candidate;
|
||||
int classNumStates = network.getClassNumStates();
|
||||
for (int i = 0; i < classNumStates; ++i) {
|
||||
result.push_back(1.0);
|
||||
auto complete_evidence = map<string, int>(evidence);
|
||||
complete_evidence[network.getClassName()] = i;
|
||||
result[i] = computeFactor(complete_evidence);
|
||||
}
|
||||
// Normalize result
|
||||
auto sum = accumulate(result.begin(), result.end(), 0.0);
|
||||
for (int i = 0; i < result.size(); ++i) {
|
||||
result[i] /= sum;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
@ -1,20 +0,0 @@
|
||||
#ifndef EXACTINFERENCE_H
|
||||
#define EXACTINFERENCE_H
|
||||
#include "Network.h"
|
||||
#include "Node.h"
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
using namespace std;
|
||||
|
||||
namespace bayesnet {
|
||||
class ExactInference {
|
||||
private:
|
||||
Network network;
|
||||
double computeFactor(map<string, int>&);
|
||||
public:
|
||||
ExactInference(Network&);
|
||||
vector<double> variableElimination(map<string, int>&);
|
||||
};
|
||||
}
|
||||
#endif
|
@ -1,5 +1,4 @@
|
||||
#include "Network.h"
|
||||
#include "ExactInference.h"
|
||||
namespace bayesnet {
|
||||
Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {}
|
||||
Network::Network(int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {}
|
||||
@ -157,12 +156,11 @@ namespace bayesnet {
|
||||
throw invalid_argument("Sample size (" + to_string(sample.size()) +
|
||||
") does not match the number of features (" + to_string(features.size()) + ")");
|
||||
}
|
||||
auto inference = ExactInference(*this);
|
||||
map<string, int> evidence;
|
||||
for (int i = 0; i < sample.size(); ++i) {
|
||||
evidence[features[i]] = sample[i];
|
||||
}
|
||||
vector<double> classProbabilities = inference.variableElimination(evidence);
|
||||
vector<double> classProbabilities = exactInference(evidence);
|
||||
|
||||
// Find the class with the maximum posterior probability
|
||||
auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end());
|
||||
@ -171,4 +169,29 @@ namespace bayesnet {
|
||||
|
||||
return make_pair(predictedClass, maxProbability);
|
||||
}
|
||||
double Network::computeFactor(map<string, int>& completeEvidence)
|
||||
{
|
||||
double result = 1.0;
|
||||
for (auto node : getNodes()) {
|
||||
result *= node.second->getFactorValue(completeEvidence);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
vector<double> Network::exactInference(map<string, int>& evidence)
|
||||
{
|
||||
vector<double> result;
|
||||
int classNumStates = getClassNumStates();
|
||||
for (int i = 0; i < classNumStates; ++i) {
|
||||
result.push_back(1.0);
|
||||
auto complete_evidence = map<string, int>(evidence);
|
||||
complete_evidence[getClassName()] = i;
|
||||
result[i] = computeFactor(complete_evidence);
|
||||
}
|
||||
// Normalize result
|
||||
auto sum = accumulate(result.begin(), result.end(), 0.0);
|
||||
for (int i = 0; i < result.size(); ++i) {
|
||||
result[i] /= sum;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
@ -17,6 +17,8 @@ namespace bayesnet {
|
||||
int laplaceSmoothing;
|
||||
bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
|
||||
pair<int, double> predict_sample(const vector<int>&);
|
||||
vector<double> exactInference(map<string, int>&);
|
||||
double computeFactor(map<string, int>&);
|
||||
public:
|
||||
Network();
|
||||
Network(int);
|
||||
|
Loading…
Reference in New Issue
Block a user