Change cpt table type to float
This commit is contained in:
parent
35ca862eca
commit
0bbc8328a9
@ -97,7 +97,7 @@ namespace bayesnet {
|
||||
dimensions.push_back(numStates);
|
||||
transform(parents.begin(), parents.end(), back_inserter(dimensions), [](const auto& parent) { return parent->getNumStates(); });
|
||||
// Create a tensor of zeros with the dimensions of the CPT
|
||||
cpTable = torch::zeros(dimensions, torch::kFloat) + smoothing;
|
||||
cpTable = torch::zeros(dimensions, torch::kDouble) + smoothing;
|
||||
// Fill table with counts
|
||||
auto pos = find(features.begin(), features.end(), name);
|
||||
if (pos == features.end()) {
|
||||
@ -118,19 +118,19 @@ namespace bayesnet {
|
||||
coordinates.push_back(sample[parent_index]);
|
||||
}
|
||||
// Increment the count of the corresponding coordinate
|
||||
cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + weights.index({ n_sample }).item<double>());
|
||||
cpTable.index_put_({ coordinates }, weights.index({ n_sample }), true);
|
||||
}
|
||||
// Normalize the counts
|
||||
// Divide each row by the sum of the row
|
||||
cpTable = cpTable / cpTable.sum(0);
|
||||
}
|
||||
float Node::getFactorValue(std::map<std::string, int>& evidence)
|
||||
double Node::getFactorValue(std::map<std::string, int>& evidence)
|
||||
{
|
||||
c10::List<c10::optional<at::Tensor>> coordinates;
|
||||
// following predetermined order of indices in the cpTable (see Node.h)
|
||||
coordinates.push_back(at::tensor(evidence[name]));
|
||||
transform(parents.begin(), parents.end(), std::back_inserter(coordinates), [&evidence](const auto& parent) { return at::tensor(evidence[parent->getName()]); });
|
||||
return cpTable.index({ coordinates }).item<float>();
|
||||
return cpTable.index({ coordinates }).item<double>();
|
||||
}
|
||||
std::vector<std::string> Node::graph(const std::string& className)
|
||||
{
|
||||
|
@ -28,7 +28,7 @@ namespace bayesnet {
|
||||
void setNumStates(int);
|
||||
unsigned minFill();
|
||||
std::vector<std::string> graph(const std::string& clasName); // Returns a std::vector of std::strings representing the graph in graphviz format
|
||||
float getFactorValue(std::map<std::string, int>&);
|
||||
double getFactorValue(std::map<std::string, int>&);
|
||||
private:
|
||||
std::string name;
|
||||
std::vector<Node*> parents;
|
||||
|
Loading…
Reference in New Issue
Block a user