ComputeCPT Optimization

This commit is contained in:
2025-02-13 01:17:37 +01:00
parent f658149977
commit dd98cf159d
3 changed files with 40 additions and 20 deletions

View File

@@ -93,36 +93,42 @@ namespace bayesnet {
void Node::computeCPT(const torch::Tensor& dataset, const std::vector<std::string>& features, const double smoothing, const torch::Tensor& weights)
{
dimensions.clear();
dimensions.reserve(parents.size() + 1);
// Get dimensions of the CPT
dimensions.push_back(numStates);
transform(parents.begin(), parents.end(), back_inserter(dimensions), [](const auto& parent) { return parent->getNumStates(); });
// Create a tensor of zeros with the dimensions of the CPT
cpTable = torch::zeros(dimensions, torch::kDouble) + smoothing;
// Fill table with counts
auto pos = find(features.begin(), features.end(), name);
if (pos == features.end()) {
throw std::logic_error("Feature " + name + " not found in dataset");
for (const auto& parent : parents) {
dimensions.push_back(parent->getNumStates());
}
//transform(parents.begin(), parents.end(), back_inserter(dimensions), [](const auto& parent) { return parent->getNumStates(); });
// Create a tensor initialized with smoothing
cpTable = torch::full(dimensions, smoothing, torch::kDouble);
// Create a map for quick feature index lookup
std::unordered_map<std::string, int> featureIndexMap;
for (size_t i = 0; i < features.size(); ++i) {
featureIndexMap[features[i]] = i;
}
// Fill table with counts
// Get the index of this node's feature
int name_index = featureIndexMap[name];
// Get parent indices in dataset
std::vector<int> parent_indices;
parent_indices.reserve(parents.size());
for (const auto& parent : parents) {
parent_indices.push_back(featureIndexMap[parent->getName()]);
}
int name_index = pos - features.begin();
c10::List<c10::optional<at::Tensor>> coordinates;
for (int n_sample = 0; n_sample < dataset.size(1); ++n_sample) {
coordinates.clear();
auto sample = dataset.index({ "...", n_sample });
coordinates.push_back(sample[name_index]);
for (auto parent : parents) {
pos = find(features.begin(), features.end(), parent->getName());
if (pos == features.end()) {
throw std::logic_error("Feature parent " + parent->getName() + " not found in dataset");
}
int parent_index = pos - features.begin();
coordinates.push_back(sample[parent_index]);
for (size_t i = 0; i < parent_indices.size(); ++i) {
coordinates.push_back(sample[parent_indices[i]]);
}
// Increment the count of the corresponding coordinate
cpTable.index_put_({ coordinates }, weights.index({ n_sample }), true);
}
// Normalize the counts
// Divide each row by the sum of the row
cpTable = cpTable / cpTable.sum(0);
// Normalize the counts (dividing each row by the sum of the row)
cpTable /= cpTable.sum(0, true);
}
double Node::getFactorValue(std::map<std::string, int>& evidence)
{