From 0bbc8328a9688a64a0c3986c707e83d83587d454 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Mon, 8 Jul 2024 13:27:55 +0200 Subject: [PATCH] Change cpt table type to float --- bayesnet/network/Node.cc | 8 ++++---- bayesnet/network/Node.h | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/bayesnet/network/Node.cc b/bayesnet/network/Node.cc index 44fc900..b62e275 100644 --- a/bayesnet/network/Node.cc +++ b/bayesnet/network/Node.cc @@ -97,7 +97,7 @@ namespace bayesnet { dimensions.push_back(numStates); transform(parents.begin(), parents.end(), back_inserter(dimensions), [](const auto& parent) { return parent->getNumStates(); }); // Create a tensor of zeros with the dimensions of the CPT - cpTable = torch::zeros(dimensions, torch::kFloat) + smoothing; + cpTable = torch::zeros(dimensions, torch::kDouble) + smoothing; // Fill table with counts auto pos = find(features.begin(), features.end(), name); if (pos == features.end()) { @@ -118,19 +118,19 @@ namespace bayesnet { coordinates.push_back(sample[parent_index]); } // Increment the count of the corresponding coordinate - cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + weights.index({ n_sample }).item()); + cpTable.index_put_({ coordinates }, weights.index({ n_sample }), true); } // Normalize the counts // Divide each row by the sum of the row cpTable = cpTable / cpTable.sum(0); } - float Node::getFactorValue(std::map& evidence) + double Node::getFactorValue(std::map& evidence) { c10::List> coordinates; // following predetermined order of indices in the cpTable (see Node.h) coordinates.push_back(at::tensor(evidence[name])); transform(parents.begin(), parents.end(), std::back_inserter(coordinates), [&evidence](const auto& parent) { return at::tensor(evidence[parent->getName()]); }); - return cpTable.index({ coordinates }).item(); + return cpTable.index({ coordinates }).item(); } std::vector Node::graph(const std::string& className) { diff --git a/bayesnet/network/Node.h b/bayesnet/network/Node.h index dc21119..b950d70 100644 --- a/bayesnet/network/Node.h +++ b/bayesnet/network/Node.h @@ -28,7 +28,7 @@ namespace bayesnet { void setNumStates(int); unsigned minFill(); std::vector graph(const std::string& clasName); // Returns a std::vector of std::strings representing the graph in graphviz format - float getFactorValue(std::map&); + double getFactorValue(std::map&); private: std::string name; std::vector parents;