From ba08b8dd3d597ab414a220bbd67dd9258cf06286 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Wed, 5 Jul 2023 18:38:54 +0200 Subject: [PATCH] Inference working --- .vscode/settings.json | 4 +- data/iris.net | 2 +- sample/test.cc | 47 ++++++++++++++++++----- src/CMakeLists.txt | 2 +- src/ExactInference.cc | 67 ++++++++------------------------- src/ExactInference.h | 11 +----- src/Factor.cc | 87 ------------------------------------------- src/Factor.h | 31 --------------- src/Network.cc | 47 +++++++---------------- src/Network.h | 3 ++ src/Node.cc | 49 +++++++++++++++--------- src/Node.h | 14 +++---- 12 files changed, 114 insertions(+), 250 deletions(-) delete mode 100644 src/Factor.cc delete mode 100644 src/Factor.h diff --git a/.vscode/settings.json b/.vscode/settings.json index 4d408cb..ef91e92 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -86,7 +86,9 @@ "*.tcc": "cpp", "functional": "cpp", "iterator": "cpp", - "memory_resource": "cpp" + "memory_resource": "cpp", + "format": "cpp", + "valarray": "cpp" }, "cmake.configureOnOpen": false, "C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools" diff --git a/data/iris.net b/data/iris.net index 7ff8cbf..d56688a 100644 --- a/data/iris.net +++ b/data/iris.net @@ -2,4 +2,4 @@ class sepallength class sepalwidth class petallength class petalwidth -# petalwidth petallength \ No newline at end of file +petalwidth petallength \ No newline at end of file diff --git a/sample/test.cc b/sample/test.cc index 4784189..334e6ba 100644 --- a/sample/test.cc +++ b/sample/test.cc @@ -21,20 +21,47 @@ // std::cout << t << std::endl; // } #include - +#include +#include +#include +using namespace std; int main() { - //torch::Tensor t = torch::rand({ 5, 4, 3 }); // 3D tensor for this example //int i = 3, j = 1, k = 2; // Indices for the cell you want to update // Print original tensor - torch::Tensor t = torch::tensor({ {1, 2, 3}, {4, 5, 6} }); // 3D tensor for this example - std::cout << t << std::endl; - std::cout << "sum(0)" << std::endl; - std::cout << t.sum(0) << std::endl; - std::cout << "sum(1)" << std::endl; - std::cout << t.sum(1) << std::endl; - std::cout << "Normalized" << std::endl; - std::cout << t / t.sum(0) << std::endl; + // torch::Tensor t = torch::tensor({ {1, 2, 3}, {4, 5, 6} }); // 3D tensor for this example + auto variables = vector{ "A", "B" }; + auto cardinalities = vector{ 5, 4 }; + torch::Tensor values = torch::rand({ 5, 4 }); + auto candidate = "B"; + vector newVariables; + vector newCardinalities; + for (int i = 0; i < variables.size(); i++) { + if (variables[i] != candidate) { + newVariables.push_back(variables[i]); + newCardinalities.push_back(cardinalities[i]); + } + } + torch::Tensor newValues = values.sum(1); + cout << "original values" << endl; + cout << values << endl; + cout << "newValues" << endl; + cout << newValues << endl; + cout << "newVariables" << endl; + for (auto& variable : newVariables) { + cout << variable << endl; + } + cout << "newCardinalities" << endl; + for (auto& cardinality : newCardinalities) { + cout << cardinality << endl; + } + // std::cout << t << std::endl; + // std::cout << "sum(0)" << std::endl; + // std::cout << t.sum(0) << std::endl; + // std::cout << "sum(1)" << std::endl; + // std::cout << t.sum(1) << std::endl; + // std::cout << "Normalized" << std::endl; + // std::cout << t / t.sum(0) << std::endl; // New value // torch::Tensor new_val = torch::tensor(10.0f); diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c27057c..a24c987 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,2 +1,2 @@ -add_library(BayesNet Network.cc Node.cc ExactInference.cc Factor.cc) +add_library(BayesNet Network.cc Node.cc ExactInference.cc) target_link_libraries(BayesNet "${TORCH_LIBRARIES}") \ No newline at end of file diff --git a/src/ExactInference.cc b/src/ExactInference.cc index 0040447..82c8d55 100644 --- a/src/ExactInference.cc +++ b/src/ExactInference.cc @@ -1,65 +1,30 @@ #include "ExactInference.h" namespace bayesnet { - ExactInference::ExactInference(Network& net) : network(net), evidence(map()), candidates(net.getFeatures()) {} - void ExactInference::setEvidence(const map& evidence) - { - this->evidence = evidence; - } - ExactInference::~ExactInference() - { - for (auto& factor : factors) { - delete factor; - } - } - void ExactInference::buildFactors() + ExactInference::ExactInference(Network& net) : network(net) {} + double ExactInference::computeFactor(map& completeEvidence) { + double result = 1.0; for (auto node : network.getNodes()) { - factors.push_back(node.second->toFactor()); - } - } - string ExactInference::nextCandidate() - { - string result = ""; - map nodes = network.getNodes(); - int minFill = INT_MAX; - for (auto candidate : candidates) { - unsigned fill = nodes[candidate]->minFill(); - if (fill < minFill) { - minFill = fill; - result = candidate; - } + result *= node.second->getFactorValue(completeEvidence); } return result; } - vector ExactInference::variableElimination() + vector ExactInference::variableElimination(map& evidence) { vector result; string candidate; - buildFactors(); - // Eliminate evidence - while ((candidate = nextCandidate()) != "") { - // Erase candidate from candidates (Erase–remove idiom) - candidates.erase(remove(candidates.begin(), candidates.end(), candidate), candidates.end()); - // sum-product variable elimination algorithm as explained in the book probabilistic graphical models - // 1. Multiply all factors containing the variable - vector factorsToMultiply; - for (auto factor : factors) { - if (factor->contains(candidate)) { - factorsToMultiply.push_back(factor); - } - } - Factor* product = Factor::product(factorsToMultiply); - // 2. Sum out the variable - Factor* sum = product->sumOut(candidate); - // 3. Remove factors containing the variable - for (auto factor : factorsToMultiply) { - factors.erase(remove(factors.begin(), factors.end(), factor), factors.end()); - delete factor; - } - // 4. Add the resulting factor to the list of factors - factors.push_back(sum); - + int classNumStates = network.getClassNumStates(); + for (int i = 0; i < classNumStates; ++i) { + result.push_back(1.0); + auto complete_evidence = map(evidence); + complete_evidence[network.getClassName()] = i; + result[i] = computeFactor(complete_evidence); + } + // Normalize result + auto sum = accumulate(result.begin(), result.end(), 0.0); + for (int i = 0; i < result.size(); ++i) { + result[i] /= sum; } return result; } diff --git a/src/ExactInference.h b/src/ExactInference.h index 87cdb6f..cc838f5 100644 --- a/src/ExactInference.h +++ b/src/ExactInference.h @@ -1,7 +1,6 @@ #ifndef EXACTINFERENCE_H #define EXACTINFERENCE_H #include "Network.h" -#include "Factor.h" #include "Node.h" #include #include @@ -12,16 +11,10 @@ namespace bayesnet { class ExactInference { private: Network network; - map evidence; - vector factors; - vector candidates; // variables to be removed - void buildFactors(); - string nextCandidate(); // Return the next variable to eliminate using MinFill criterion + double computeFactor(map&); public: ExactInference(Network&); - ~ExactInference(); - void setEvidence(const map&); - vector variableElimination(); + vector variableElimination(map&); }; } #endif \ No newline at end of file diff --git a/src/Factor.cc b/src/Factor.cc deleted file mode 100644 index 247d9ec..0000000 --- a/src/Factor.cc +++ /dev/null @@ -1,87 +0,0 @@ -#include "Factor.h" -#include -#include - -using namespace std; - -namespace bayesnet { - Factor::Factor(vector& variables, vector& cardinalities, torch::Tensor& values) : variables(variables), cardinalities(cardinalities), values(values) {} - Factor::~Factor() = default; - Factor::Factor(const Factor& other) : variables(other.variables), cardinalities(other.cardinalities), values(other.values) {} - Factor& Factor::operator=(const Factor& other) - { - if (this != &other) { - variables = other.variables; - cardinalities = other.cardinalities; - values = other.values; - } - return *this; - } - void Factor::setVariables(vector& variables) - { - this->variables = variables; - } - void Factor::setCardinalities(vector& cardinalities) - { - this->cardinalities = cardinalities; - } - void Factor::setValues(torch::Tensor& values) - { - this->values = values; - } - vector& Factor::getVariables() - { - return variables; - } - vector& Factor::getCardinalities() - { - return cardinalities; - } - torch::Tensor& Factor::getValues() - { - return values; - } - bool Factor::contains(string& variable) - { - for (int i = 0; i < variables.size(); i++) { - if (variables[i] == variable) { - return true; - } - } - return false; - } - Factor* Factor::sumOut(string& candidate) - { - vector newVariables; - vector newCardinalities; - for (int i = 0; i < variables.size(); i++) { - if (variables[i] != candidate) { - newVariables.push_back(variables[i]); - newCardinalities.push_back(cardinalities[i]); - } - } - torch::Tensor newValues = values.sum(0); - return new Factor(newVariables, newCardinalities, newValues); - } - Factor* Factor::product(vector& factors) - { - vector newVariables; - vector newCardinalities; - for (auto factor : factors) { - vector variables = factor->getVariables(); - for (auto idx = 0; idx < variables.size(); ++idx) { - string variable = variables[idx]; - if (find(newVariables.begin(), newVariables.end(), variable) == newVariables.end()) { - newVariables.push_back(variable); - newCardinalities.push_back(factor->getCardinalities()[idx]); - } - } - } - torch::Tensor newValues = factors[0]->getValues(); - for (int i = 1; i < factors.size(); i++) { - newValues = newValues.matmul(factors[i]->getValues()); - } - return new Factor(newVariables, newCardinalities, newValues); - } - -} \ No newline at end of file diff --git a/src/Factor.h b/src/Factor.h deleted file mode 100644 index f98dc48..0000000 --- a/src/Factor.h +++ /dev/null @@ -1,31 +0,0 @@ -#ifndef FACTOR_H -#define FACTOR_H -#include -#include -#include -using namespace std; - -namespace bayesnet { - class Factor { - private: - vector variables; - vector cardinalities; - torch::Tensor values; - public: - Factor(vector&, vector&, torch::Tensor&); - ~Factor(); - Factor(const Factor&); - Factor& operator=(const Factor&); - void setVariables(vector&); - void setCardinalities(vector&); - void setValues(torch::Tensor&); - vector& getVariables(); - vector& getCardinalities(); - bool contains(string&); - torch::Tensor& getValues(); - static Factor* product(vector&); - Factor* sumOut(string&); - - }; -} -#endif \ No newline at end of file diff --git a/src/Network.cc b/src/Network.cc index 30956a1..f372c8c 100644 --- a/src/Network.cc +++ b/src/Network.cc @@ -1,9 +1,9 @@ #include "Network.h" #include "ExactInference.h" namespace bayesnet { - Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector()), className("") {} - Network::Network(int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector()), className("") {} - Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), root(other.root), features(other.features), className(other.className) + Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector()), className(""), classNumStates(0) {} + Network::Network(int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector()), className(""), classNumStates(0) {} + Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), root(other.root), features(other.features), className(other.className), classNumStates(other.getClassNumStates()) { for (auto& pair : other.nodes) { nodes[pair.first] = new Node(*pair.second); @@ -31,6 +31,14 @@ namespace bayesnet { { return features; } + int Network::getClassNumStates() + { + return classNumStates; + } + string Network::getClassName() + { + return className; + } void Network::setRoot(string name) { if (nodes.find(name) == nodes.end()) { @@ -93,6 +101,7 @@ namespace bayesnet { this->dataset[featureNames[i]] = dataset[i]; } this->dataset[className] = labels; + this->classNumStates = *max_element(labels.begin(), labels.end()) + 1; estimateParameters(); } @@ -100,29 +109,7 @@ namespace bayesnet { { auto dimensions = vector(); for (auto [name, node] : nodes) { - // Get dimensions of the CPT - dimensions.clear(); - dimensions.push_back(node->getNumStates()); - for (auto father : node->getParents()) { - dimensions.push_back(father->getNumStates()); - } - auto length = dimensions.size(); - // Create a tensor of zeros with the dimensions of the CPT - torch::Tensor cpt = torch::zeros(dimensions, torch::kFloat) + laplaceSmoothing; - // Fill table with counts - for (int n_sample = 0; n_sample < dataset[name].size(); ++n_sample) { - torch::List> coordinates; - coordinates.push_back(torch::tensor(dataset[name][n_sample])); - for (auto father : node->getParents()) { - coordinates.push_back(torch::tensor(dataset[father->getName()][n_sample])); - } - // Increment the count of the corresponding coordinate - cpt.index_put_({ coordinates }, cpt.index({ coordinates }) + 1); - } - // Normalize the counts - cpt = cpt / cpt.sum(0); - // store thre resulting cpt in the node - node->setCPT(cpt); + node->computeCPT(dataset, laplaceSmoothing); } } @@ -175,14 +162,8 @@ namespace bayesnet { for (int i = 0; i < sample.size(); ++i) { evidence[features[i]] = sample[i]; } - inference.setEvidence(evidence); - vector classProbabilities = inference.variableElimination(); + vector classProbabilities = inference.variableElimination(evidence); - // Normalize the probabilities to sum to 1 - double sum = accumulate(classProbabilities.begin(), classProbabilities.end(), 0.0); - for (double& prob : classProbabilities) { - prob /= sum; - } // Find the class with the maximum posterior probability auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end()); int predictedClass = distance(classProbabilities.begin(), maxElem); diff --git a/src/Network.h b/src/Network.h index aa0285b..76050de 100644 --- a/src/Network.h +++ b/src/Network.h @@ -11,6 +11,7 @@ namespace bayesnet { map nodes; map> dataset; Node* root; + int classNumStates; vector features; string className; int laplaceSmoothing; @@ -25,6 +26,8 @@ namespace bayesnet { void addEdge(const string, const string); map& getNodes(); vector getFeatures(); + int getClassNumStates(); + string getClassName(); void fit(const vector>&, const vector&, const vector&, const string&); void estimateParameters(); void setRoot(string); diff --git a/src/Node.cc b/src/Node.cc index ab2e193..075353d 100644 --- a/src/Node.cc +++ b/src/Node.cc @@ -1,10 +1,9 @@ #include "Node.h" namespace bayesnet { - int Node::next_id = 0; Node::Node(const std::string& name, int numStates) - : id(next_id++), name(name), numStates(numStates), cpt(torch::Tensor()), parents(vector()), children(vector()) + : name(name), numStates(numStates), cpTable(torch::Tensor()), parents(vector()), children(vector()) { } @@ -47,11 +46,7 @@ namespace bayesnet { } torch::Tensor& Node::getCPT() { - return cpt; - } - void Node::setCPT(const torch::Tensor& cpt) - { - this->cpt = cpt; + return cpTable; } /* The MinFill criterion is a heuristic for variable elimination. @@ -83,17 +78,37 @@ namespace bayesnet { } return result; } - Factor* Node::toFactor() + void Node::computeCPT(map>& dataset, const int laplaceSmoothing) { - vector variables; - vector cardinalities; - variables.push_back(name); - cardinalities.push_back(numStates); - for (auto parent : parents) { - variables.push_back(parent->getName()); - cardinalities.push_back(parent->getNumStates()); + // Get dimensions of the CPT + dimensions.push_back(numStates); + for (auto father : getParents()) { + dimensions.push_back(father->getNumStates()); } - return new Factor(variables, cardinalities, cpt); - + auto length = dimensions.size(); + // Create a tensor of zeros with the dimensions of the CPT + cpTable = torch::zeros(dimensions, torch::kFloat) + laplaceSmoothing; + // Fill table with counts + for (int n_sample = 0; n_sample < dataset[name].size(); ++n_sample) { + torch::List> coordinates; + coordinates.push_back(torch::tensor(dataset[name][n_sample])); + for (auto father : getParents()) { + coordinates.push_back(torch::tensor(dataset[father->getName()][n_sample])); + } + // Increment the count of the corresponding coordinate + cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + 1); + } + // Normalize the counts + cpTable = cpTable / cpTable.sum(0); + } + float Node::getFactorValue(map& evidence) + { + torch::List> coordinates; + // following predetermined order of indices in the cpTable (see Node.h) + coordinates.push_back(torch::tensor(evidence[name])); + for (auto parent : getParents()) { + coordinates.push_back(torch::tensor(evidence[parent->getName()])); + } + return cpTable.index({ coordinates }).item(); } } \ No newline at end of file diff --git a/src/Node.h b/src/Node.h index 2dd2d1b..39189ce 100644 --- a/src/Node.h +++ b/src/Node.h @@ -1,21 +1,18 @@ #ifndef NODE_H #define NODE_H #include -#include "Factor.h" #include #include namespace bayesnet { using namespace std; class Node { private: - static int next_id; - const int id; string name; vector parents; vector children; - torch::Tensor cpTable; - int numStates; - torch::Tensor cpt; + int numStates; // number of states of the variable + torch::Tensor cpTable; // Order of indices is 0-> node variable, 1-> 1st parent, 2-> 2nd parent, ... + vector dimensions; // dimensions of the cpTable vector combinations(const set&); public: Node(const std::string&, int); @@ -27,12 +24,11 @@ namespace bayesnet { vector& getParents(); vector& getChildren(); torch::Tensor& getCPT(); - void setCPT(const torch::Tensor&); + void computeCPT(map>&, const int); int getNumStates() const; void setNumStates(int); unsigned minFill(); - int getId() const { return id; } - Factor* toFactor(); + float getFactorValue(map&); }; } #endif \ No newline at end of file