ComputeCPT Optimization
This commit is contained in:
@@ -93,36 +93,42 @@ namespace bayesnet {
|
||||
void Node::computeCPT(const torch::Tensor& dataset, const std::vector<std::string>& features, const double smoothing, const torch::Tensor& weights)
|
||||
{
|
||||
dimensions.clear();
|
||||
dimensions.reserve(parents.size() + 1);
|
||||
// Get dimensions of the CPT
|
||||
dimensions.push_back(numStates);
|
||||
transform(parents.begin(), parents.end(), back_inserter(dimensions), [](const auto& parent) { return parent->getNumStates(); });
|
||||
// Create a tensor of zeros with the dimensions of the CPT
|
||||
cpTable = torch::zeros(dimensions, torch::kDouble) + smoothing;
|
||||
// Fill table with counts
|
||||
auto pos = find(features.begin(), features.end(), name);
|
||||
if (pos == features.end()) {
|
||||
throw std::logic_error("Feature " + name + " not found in dataset");
|
||||
for (const auto& parent : parents) {
|
||||
dimensions.push_back(parent->getNumStates());
|
||||
}
|
||||
//transform(parents.begin(), parents.end(), back_inserter(dimensions), [](const auto& parent) { return parent->getNumStates(); });
|
||||
// Create a tensor initialized with smoothing
|
||||
cpTable = torch::full(dimensions, smoothing, torch::kDouble);
|
||||
// Create a map for quick feature index lookup
|
||||
std::unordered_map<std::string, int> featureIndexMap;
|
||||
for (size_t i = 0; i < features.size(); ++i) {
|
||||
featureIndexMap[features[i]] = i;
|
||||
}
|
||||
// Fill table with counts
|
||||
// Get the index of this node's feature
|
||||
int name_index = featureIndexMap[name];
|
||||
// Get parent indices in dataset
|
||||
std::vector<int> parent_indices;
|
||||
parent_indices.reserve(parents.size());
|
||||
for (const auto& parent : parents) {
|
||||
parent_indices.push_back(featureIndexMap[parent->getName()]);
|
||||
}
|
||||
int name_index = pos - features.begin();
|
||||
c10::List<c10::optional<at::Tensor>> coordinates;
|
||||
for (int n_sample = 0; n_sample < dataset.size(1); ++n_sample) {
|
||||
coordinates.clear();
|
||||
auto sample = dataset.index({ "...", n_sample });
|
||||
coordinates.push_back(sample[name_index]);
|
||||
for (auto parent : parents) {
|
||||
pos = find(features.begin(), features.end(), parent->getName());
|
||||
if (pos == features.end()) {
|
||||
throw std::logic_error("Feature parent " + parent->getName() + " not found in dataset");
|
||||
}
|
||||
int parent_index = pos - features.begin();
|
||||
coordinates.push_back(sample[parent_index]);
|
||||
for (size_t i = 0; i < parent_indices.size(); ++i) {
|
||||
coordinates.push_back(sample[parent_indices[i]]);
|
||||
}
|
||||
// Increment the count of the corresponding coordinate
|
||||
cpTable.index_put_({ coordinates }, weights.index({ n_sample }), true);
|
||||
}
|
||||
// Normalize the counts
|
||||
// Divide each row by the sum of the row
|
||||
cpTable = cpTable / cpTable.sum(0);
|
||||
// Normalize the counts (dividing each row by the sum of the row)
|
||||
cpTable /= cpTable.sum(0, true);
|
||||
}
|
||||
double Node::getFactorValue(std::map<std::string, int>& evidence)
|
||||
{
|
||||
|
Reference in New Issue
Block a user