Refactor ExactInference
This commit is contained in:
parent
3de1967b3e
commit
9b70708afb
@ -1,2 +1,2 @@
|
|||||||
add_library(BayesNet Network.cc Node.cc ExactInference.cc)
|
add_library(BayesNet Network.cc Node.cc)
|
||||||
target_link_libraries(BayesNet "${TORCH_LIBRARIES}")
|
target_link_libraries(BayesNet "${TORCH_LIBRARIES}")
|
@ -1,31 +0,0 @@
|
|||||||
#include "ExactInference.h"
|
|
||||||
|
|
||||||
namespace bayesnet {
|
|
||||||
ExactInference::ExactInference(Network& net) : network(net) {}
|
|
||||||
double ExactInference::computeFactor(map<string, int>& completeEvidence)
|
|
||||||
{
|
|
||||||
double result = 1.0;
|
|
||||||
for (auto node : network.getNodes()) {
|
|
||||||
result *= node.second->getFactorValue(completeEvidence);
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
vector<double> ExactInference::variableElimination(map<string, int>& evidence)
|
|
||||||
{
|
|
||||||
vector<double> result;
|
|
||||||
string candidate;
|
|
||||||
int classNumStates = network.getClassNumStates();
|
|
||||||
for (int i = 0; i < classNumStates; ++i) {
|
|
||||||
result.push_back(1.0);
|
|
||||||
auto complete_evidence = map<string, int>(evidence);
|
|
||||||
complete_evidence[network.getClassName()] = i;
|
|
||||||
result[i] = computeFactor(complete_evidence);
|
|
||||||
}
|
|
||||||
// Normalize result
|
|
||||||
auto sum = accumulate(result.begin(), result.end(), 0.0);
|
|
||||||
for (int i = 0; i < result.size(); ++i) {
|
|
||||||
result[i] /= sum;
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,20 +0,0 @@
|
|||||||
#ifndef EXACTINFERENCE_H
|
|
||||||
#define EXACTINFERENCE_H
|
|
||||||
#include "Network.h"
|
|
||||||
#include "Node.h"
|
|
||||||
#include <map>
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
namespace bayesnet {
|
|
||||||
class ExactInference {
|
|
||||||
private:
|
|
||||||
Network network;
|
|
||||||
double computeFactor(map<string, int>&);
|
|
||||||
public:
|
|
||||||
ExactInference(Network&);
|
|
||||||
vector<double> variableElimination(map<string, int>&);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
#endif
|
|
@ -1,5 +1,4 @@
|
|||||||
#include "Network.h"
|
#include "Network.h"
|
||||||
#include "ExactInference.h"
|
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {}
|
Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {}
|
||||||
Network::Network(int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {}
|
Network::Network(int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {}
|
||||||
@ -157,12 +156,11 @@ namespace bayesnet {
|
|||||||
throw invalid_argument("Sample size (" + to_string(sample.size()) +
|
throw invalid_argument("Sample size (" + to_string(sample.size()) +
|
||||||
") does not match the number of features (" + to_string(features.size()) + ")");
|
") does not match the number of features (" + to_string(features.size()) + ")");
|
||||||
}
|
}
|
||||||
auto inference = ExactInference(*this);
|
|
||||||
map<string, int> evidence;
|
map<string, int> evidence;
|
||||||
for (int i = 0; i < sample.size(); ++i) {
|
for (int i = 0; i < sample.size(); ++i) {
|
||||||
evidence[features[i]] = sample[i];
|
evidence[features[i]] = sample[i];
|
||||||
}
|
}
|
||||||
vector<double> classProbabilities = inference.variableElimination(evidence);
|
vector<double> classProbabilities = exactInference(evidence);
|
||||||
|
|
||||||
// Find the class with the maximum posterior probability
|
// Find the class with the maximum posterior probability
|
||||||
auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end());
|
auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end());
|
||||||
@ -171,4 +169,29 @@ namespace bayesnet {
|
|||||||
|
|
||||||
return make_pair(predictedClass, maxProbability);
|
return make_pair(predictedClass, maxProbability);
|
||||||
}
|
}
|
||||||
|
double Network::computeFactor(map<string, int>& completeEvidence)
|
||||||
|
{
|
||||||
|
double result = 1.0;
|
||||||
|
for (auto node : getNodes()) {
|
||||||
|
result *= node.second->getFactorValue(completeEvidence);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
vector<double> Network::exactInference(map<string, int>& evidence)
|
||||||
|
{
|
||||||
|
vector<double> result;
|
||||||
|
int classNumStates = getClassNumStates();
|
||||||
|
for (int i = 0; i < classNumStates; ++i) {
|
||||||
|
result.push_back(1.0);
|
||||||
|
auto complete_evidence = map<string, int>(evidence);
|
||||||
|
complete_evidence[getClassName()] = i;
|
||||||
|
result[i] = computeFactor(complete_evidence);
|
||||||
|
}
|
||||||
|
// Normalize result
|
||||||
|
auto sum = accumulate(result.begin(), result.end(), 0.0);
|
||||||
|
for (int i = 0; i < result.size(); ++i) {
|
||||||
|
result[i] /= sum;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -17,6 +17,8 @@ namespace bayesnet {
|
|||||||
int laplaceSmoothing;
|
int laplaceSmoothing;
|
||||||
bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
|
bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
|
||||||
pair<int, double> predict_sample(const vector<int>&);
|
pair<int, double> predict_sample(const vector<int>&);
|
||||||
|
vector<double> exactInference(map<string, int>&);
|
||||||
|
double computeFactor(map<string, int>&);
|
||||||
public:
|
public:
|
||||||
Network();
|
Network();
|
||||||
Network(int);
|
Network(int);
|
||||||
|
Loading…
Reference in New Issue
Block a user