ComputeCPT Optimization

2025-05-13 17:43:17 +02:00
parent b11620bbe8
commit 250036f224
3 changed files with 44 additions and 23 deletions

@@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fix the vcpkg configuration in building the library.
 - Fix the sample app to use the vcpkg configuration.
 - Add predict_proba method to all Ld classifiers.
+- Optimize the computeCPT method in the Node class with libtorch vectorized operations and remove the for loop.
 ## [1.1.0] - 2025-04-27
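
The entry above describes replacing the per-sample counting loop in Node::computeCPT with vectorized libtorch operations. As a rough, standalone sketch of that idea (not code from this repository; the toy dataset, its 2/2/3 state counts, and the uniform weights are invented for illustration), the snippet below accumulates all weighted counts for a node and two parents in a single index_put_ call with accumulate=true. The actual change in Node.cc further down flattens the table and uses index_add_ instead, but the loop-removal idea is the same.

#include <torch/torch.h>
#include <iostream>

// Minimal sketch (not repository code): accumulate weighted counts for a node
// with two parents in one vectorized call instead of a per-sample loop.
int main() {
    // Assumed toy layout: 3 feature rows x 6 samples; row 0 is the node
    // (2 states), rows 1 and 2 are its parents (2 and 3 states).
    auto dataset = torch::tensor({ { 0, 1, 1, 0, 1, 0 },
                                   { 1, 0, 1, 1, 0, 0 },
                                   { 2, 0, 1, 2, 1, 0 } }, torch::kLong);
    auto weights = torch::ones({ 6 }, torch::kDouble);
    auto cpTable = torch::zeros({ 2, 2, 3 }, torch::kDouble);
    // One index tensor per CPT dimension, each holding that feature's value for
    // every sample; accumulate=true adds every sample's weight in a single call.
    c10::List<c10::optional<at::Tensor>> coordinates;
    coordinates.push_back(dataset[0]);
    coordinates.push_back(dataset[1]);
    coordinates.push_back(dataset[2]);
    cpTable.index_put_({ coordinates }, weights, true);
    // Normalize over the node's own states (dim 0) to get P(node | parents).
    cpTable /= cpTable.sum(0, true);
    std::cout << cpTable << std::endl;
    return 0;
}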

@@ -45,6 +45,7 @@ namespace bayesnet {
     }
     torch::Tensor SPODELd::predict_proba(torch::Tensor& X)
     {
+        std::cout << "Debug: SPODELd::predict_proba" << std::endl;
         auto Xt = prepareX(X);
         return SPODE::predict_proba(Xt);
     }

@@ -99,36 +99,55 @@ namespace bayesnet {
         for (const auto& parent : parents) {
             dimensions.push_back(parent->getNumStates());
         }
         //transform(parents.begin(), parents.end(), back_inserter(dimensions), [](const auto& parent) { return parent->getNumStates(); });
         // Create a tensor initialized with smoothing
         cpTable = torch::full(dimensions, smoothing, torch::kDouble);
-        // Create a map for quick feature index lookup
-        std::unordered_map<std::string, int> featureIndexMap;
-        for (size_t i = 0; i < features.size(); ++i) {
-            featureIndexMap[features[i]] = i;
-        }
-        // Fill table with counts
-        // Get the index of this node's feature
-        int name_index = featureIndexMap[name];
-        // Get parent indices in dataset
-        std::vector<int> parent_indices;
-        parent_indices.reserve(parents.size());
-        for (const auto& parent : parents) {
-            parent_indices.push_back(featureIndexMap[parent->getName()]);
-        }
-        c10::List<c10::optional<at::Tensor>> coordinates;
-        for (int n_sample = 0; n_sample < dataset.size(1); ++n_sample) {
-            coordinates.clear();
-            auto sample = dataset.index({ "...", n_sample });
-            coordinates.push_back(sample[name_index]);
-            for (size_t i = 0; i < parent_indices.size(); ++i) {
-                coordinates.push_back(sample[parent_indices[i]]);
+        std::unordered_map<std::string, int> cachedFeatureIndexMap;
+        bool featureIndexMapReady = false;
+        // Build featureIndexMap if not ready
+        if (!featureIndexMapReady) {
+            cachedFeatureIndexMap.clear();
+            for (size_t i = 0; i < features.size(); ++i) {
+                cachedFeatureIndexMap[features[i]] = i;
             }
-            // Increment the count of the corresponding coordinate
-            cpTable.index_put_({ coordinates }, weights.index({ n_sample }), true);
+            featureIndexMapReady = true;
         }
+        const auto& featureIndexMap = cachedFeatureIndexMap;
+        // Gather indices for node and parents
+        std::vector<int64_t> all_indices;
+        all_indices.push_back(featureIndexMap.at(name));
+        for (const auto& parent : parents) {
+            all_indices.push_back(featureIndexMap.at(parent->getName()));
+        }
+        // Extract relevant columns: shape (num_features, num_samples)
+        auto indices_tensor = dataset.index_select(0, torch::tensor(all_indices, torch::kLong));
+        // Transpose to (num_samples, num_features)
+        indices_tensor = indices_tensor.transpose(0, 1).to(torch::kLong);
+        // Flatten CPT for easier indexing
+        auto flat_cpt = cpTable.flatten();
+        // Compute strides for flattening multi-dim indices
+        std::vector<int64_t> strides(all_indices.size(), 1);
+        for (int i = strides.size() - 2; i >= 0; --i) {
+            strides[i] = strides[i + 1] * cpTable.size(i + 1);
+        }
+        // Compute flat indices for each sample
+        auto indices_tensor_cpu = indices_tensor.cpu();
+        auto indices_accessor = indices_tensor_cpu.accessor<int64_t, 2>();
+        std::vector<int64_t> flat_indices(indices_tensor.size(0));
+        for (int64_t i = 0; i < indices_tensor.size(0); ++i) {
+            int64_t idx = 0;
+            for (size_t j = 0; j < strides.size(); ++j) {
+                idx += indices_accessor[i][j] * strides[j];
+            }
+            flat_indices[i] = idx;
+        }
+        // Accumulate weights into flat CPT
+        auto flat_indices_tensor = torch::from_blob(flat_indices.data(), { (int64_t)flat_indices.size() }, torch::kLong).clone();
+        flat_cpt.index_add_(0, flat_indices_tensor, weights.cpu());
+        cpTable = flat_cpt.view(cpTable.sizes());
         // Normalize the counts (dividing each row by the sum of the row)
         cpTable /= cpTable.sum(0, true);
         return;
     }
     double Node::getFactorValue(std::map<std::string, int>& evidence)
     {
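
A note on the flattening step in the hunk above: because the strides are built in row-major order (the last dimension gets stride 1), the flat position sum_j sample[j] * strides[j] addresses exactly the cell that multi-dimensional indexing with those coordinates would reach, which is what lets a single index_add_ accumulate every sample's weight at once. The following standalone sketch (not repository code; the 2x2x3 table, the two samples, and the 0.5 weights are made up for illustration) checks that invariant:

#include <torch/torch.h>
#include <iostream>
#include <vector>

// Minimal sketch (not repository code) of the flattening trick used in the
// diff above: with row-major strides, a coordinate (i, j, k) maps to one
// position in cpTable.flatten(), so index_add_ on the flat view accumulates
// into the same cell that cpTable[i][j][k] addresses.
int main() {
    auto cpTable = torch::zeros({ 2, 2, 3 }, torch::kDouble);
    // Same stride recurrence as in the diff: last stride is 1, then
    // strides[i] = strides[i + 1] * cpTable.size(i + 1).
    std::vector<int64_t> strides(3, 1);
    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
        strides[i] = strides[i + 1] * cpTable.size(i + 1);
    }
    // Two toy samples that both land on coordinate (1, 0, 2), weight 0.5 each.
    std::vector<int64_t> flat_indices = {
        1 * strides[0] + 0 * strides[1] + 2 * strides[2],
        1 * strides[0] + 0 * strides[1] + 2 * strides[2]
    };
    auto flat_cpt = cpTable.flatten();
    flat_cpt.index_add_(0, torch::tensor(flat_indices, torch::kLong),
                        torch::full({ 2 }, 0.5, torch::kDouble));
    cpTable = flat_cpt.view(cpTable.sizes());
    // Prints 1: both weights were accumulated into cpTable[1][0][2].
    std::cout << cpTable.index({ 1, 0, 2 }).item<double>() << std::endl;
    return 0;
}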