From 250036f2241894ead55d5767c6bade56c1aac7b8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?=
Date: Tue, 13 May 2025 17:43:17 +0200
Subject: [PATCH] ComputeCPT Optimization

---
 CHANGELOG.md                    |  1 +
 bayesnet/classifiers/SPODELd.cc |  1 +
 bayesnet/network/Node.cc        | 65 +++++++++++++++++++++------------
 3 files changed, 44 insertions(+), 23 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0077223..24e4b89 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fix the vcpkg configuration in building the library.
 - Fix the sample app to use the vcpkg configuration.
 - Add predict_proba method to all Ld classifiers.
+- Optimize the computeCPT method in the Node class with libtorch vectorized operations and remove the for loop.

 ## [1.1.0] - 2025-04-27

diff --git a/bayesnet/classifiers/SPODELd.cc b/bayesnet/classifiers/SPODELd.cc
index c68b7d9..a912261 100644
--- a/bayesnet/classifiers/SPODELd.cc
+++ b/bayesnet/classifiers/SPODELd.cc
@@ -45,6 +45,7 @@ namespace bayesnet {
     }
     torch::Tensor SPODELd::predict_proba(torch::Tensor& X)
     {
+        std::cout << "Debug: SPODELd::predict_proba" << std::endl;
         auto Xt = prepareX(X);
         return SPODE::predict_proba(Xt);
     }
diff --git a/bayesnet/network/Node.cc b/bayesnet/network/Node.cc
index 1b2381f..a66fc8a 100644
--- a/bayesnet/network/Node.cc
+++ b/bayesnet/network/Node.cc
@@ -99,36 +99,55 @@ namespace bayesnet {
         for (const auto& parent : parents) {
             dimensions.push_back(parent->getNumStates());
         }
-        //transform(parents.begin(), parents.end(), back_inserter(dimensions), [](const auto& parent) { return parent->getNumStates(); });
         // Create a tensor initialized with smoothing
         cpTable = torch::full(dimensions, smoothing, torch::kDouble);
         // Create a map for quick feature index lookup
-        std::unordered_map<std::string, int> featureIndexMap;
-        for (size_t i = 0; i < features.size(); ++i) {
-            featureIndexMap[features[i]] = i;
-        }
-        // Fill table with counts
-        // Get the index of this node's feature
-        int name_index = featureIndexMap[name];
-        // Get parent indices in dataset
-        std::vector<int> parent_indices;
-        parent_indices.reserve(parents.size());
-        for (const auto& parent : parents) {
-            parent_indices.push_back(featureIndexMap[parent->getName()]);
-        }
-        c10::List<c10::optional<at::Tensor>> coordinates;
-        for (int n_sample = 0; n_sample < dataset.size(1); ++n_sample) {
-            coordinates.clear();
-            auto sample = dataset.index({ "...", n_sample });
-            coordinates.push_back(sample[name_index]);
-            for (size_t i = 0; i < parent_indices.size(); ++i) {
-                coordinates.push_back(sample[parent_indices[i]]);
+        std::unordered_map<std::string, int> cachedFeatureIndexMap;
+        bool featureIndexMapReady = false;
+        // Build featureIndexMap if not ready
+        if (!featureIndexMapReady) {
+            cachedFeatureIndexMap.clear();
+            for (size_t i = 0; i < features.size(); ++i) {
+                cachedFeatureIndexMap[features[i]] = i;
             }
-            // Increment the count of the corresponding coordinate
-            cpTable.index_put_({ coordinates }, weights.index({ n_sample }), true);
+            featureIndexMapReady = true;
         }
+        const auto& featureIndexMap = cachedFeatureIndexMap;
+        // Gather indices for node and parents
+        std::vector<int64_t> all_indices;
+        all_indices.push_back(featureIndexMap.at(name));
+        for (const auto& parent : parents) {
+            all_indices.push_back(featureIndexMap.at(parent->getName()));
+        }
+        // Extract relevant columns: shape (num_features, num_samples)
+        auto indices_tensor = dataset.index_select(0, torch::tensor(all_indices, torch::kLong));
+        // Transpose to (num_samples, num_features)
+        indices_tensor = indices_tensor.transpose(0, 1).to(torch::kLong);
+        // Flatten CPT for easier indexing
+        auto flat_cpt = cpTable.flatten();
+        // Compute strides for flattening multi-dim indices
+        std::vector<int64_t> strides(all_indices.size(), 1);
+        for (int i = strides.size() - 2; i >= 0; --i) {
+            strides[i] = strides[i + 1] * cpTable.size(i + 1);
+        }
+        // Compute flat indices for each sample
+        auto indices_tensor_cpu = indices_tensor.cpu();
+        auto indices_accessor = indices_tensor_cpu.accessor<int64_t, 2>();
+        std::vector<int64_t> flat_indices(indices_tensor.size(0));
+        for (int64_t i = 0; i < indices_tensor.size(0); ++i) {
+            int64_t idx = 0;
+            for (size_t j = 0; j < strides.size(); ++j) {
+                idx += indices_accessor[i][j] * strides[j];
+            }
+            flat_indices[i] = idx;
+        }
+        // Accumulate weights into flat CPT
+        auto flat_indices_tensor = torch::from_blob(flat_indices.data(), { (int64_t)flat_indices.size() }, torch::kLong).clone();
+        flat_cpt.index_add_(0, flat_indices_tensor, weights.cpu());
+        cpTable = flat_cpt.view(cpTable.sizes());
         // Normalize the counts (dividing each row by the sum of the row)
         cpTable /= cpTable.sum(0, true);
+        return;
     }
     double Node::getFactorValue(std::map<std::string, int>& evidence)
     {
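A note on the technique (not part of the patch): the optimized computeCPT is essentially a vectorized weighted histogram. Each sample's (node state, parent states) tuple is collapsed into a single flat index using row-major strides, and all sample weights are accumulated into the flattened CPT with one index_add_ call, replacing the per-sample index_put_ loop. The sketch below illustrates the same idea on a made-up two-variable example; the toy data, dimensions, and the tensor multiply used to build the flat indices are illustrative assumptions rather than code from the repository (the patch builds the flat indices with a CPU accessor loop and initializes the table with a smoothing prior).

#include <torch/torch.h>
#include <iostream>
#include <vector>

int main() {
    // Toy layout mirroring the patch: rows = variables, columns = samples.
    // The node has 3 states, its single parent has 2 states, 5 samples.
    auto data = torch::tensor({ 0, 1, 2, 1, 0,        // node states
                                1, 0, 1, 1, 0 },      // parent states
                              torch::kLong).view({ 2, 5 });
    auto weights = torch::ones({ 5 }, torch::kDouble); // one count per sample

    std::vector<int64_t> dims = { 3, 2 };              // CPT shape: node x parent
    auto cpt = torch::zeros(dims, torch::kDouble);     // no smoothing in this sketch

    // Row-major strides: flat = node_state * 2 + parent_state * 1.
    std::vector<int64_t> strides(dims.size(), 1);
    for (int i = (int)strides.size() - 2; i >= 0; --i) {
        strides[i] = strides[i + 1] * dims[i + 1];
    }
    auto strides_t = torch::tensor(strides, torch::kLong);

    // Collapse every sample's state tuple into one flat index, fully vectorized.
    auto flat_idx = (data.transpose(0, 1) * strides_t).sum(1);

    // Accumulate all sample weights into the flattened CPT in a single call.
    auto flat_cpt = cpt.flatten();
    flat_cpt.index_add_(0, flat_idx, weights);
    cpt = flat_cpt.view(dims);

    // Normalize over the node dimension to obtain conditional probabilities.
    cpt /= cpt.sum(0, true);
    std::cout << cpt << std::endl;
    return 0;
}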