Optimize ComputeCPT method, reducing its running time by approx. 30%
This commit is contained in:
@@ -23,6 +23,7 @@ namespace bayesnet {
|
||||
throw std::invalid_argument("y must be an integer tensor");
|
||||
}
|
||||
}
|
||||
// Fit method for single classifier
|
||||
map<std::string, std::vector<int>> Proposal::localDiscretizationProposal(const map<std::string, std::vector<int>>& oldStates, Network& model)
|
||||
{
|
||||
// order of local discretization is important. no good 0, 1, 2...
|
||||
|
@@ -45,7 +45,6 @@ namespace bayesnet {
|
||||
}
|
||||
// Passes X through prepareX and returns the result of SPODE::predict_proba
// on the prepared tensor.
//
// @param X  input tensor; forwarded to prepareX as-is.
// @return   whatever SPODE::predict_proba returns for the prepared tensor.
torch::Tensor SPODELd::predict_proba(torch::Tensor& X)
{
    // Deliberately no diagnostic printing here: writing to stdout (and
    // flushing it) on every prediction call pollutes the caller's output.
    auto Xt = prepareX(X);
    return SPODE::predict_proba(Xt);
}
|
||||
|
@@ -5,6 +5,7 @@
|
||||
// ***************************************************************
|
||||
|
||||
#include "Node.h"
|
||||
#include <iterator>
|
||||
|
||||
namespace bayesnet {
|
||||
|
||||
@@ -94,43 +95,34 @@ namespace bayesnet {
|
||||
{
|
||||
dimensions.clear();
|
||||
dimensions.reserve(parents.size() + 1);
|
||||
// Get dimensions of the CPT
|
||||
dimensions.push_back(numStates);
|
||||
for (const auto& parent : parents) {
|
||||
dimensions.push_back(parent->getNumStates());
|
||||
}
|
||||
// Create a tensor initialized with smoothing
|
||||
cpTable = torch::full(dimensions, smoothing, torch::kDouble);
|
||||
// Create a map for quick feature index lookup
|
||||
std::unordered_map<std::string, int> cachedFeatureIndexMap;
|
||||
bool featureIndexMapReady = false;
|
||||
// Build featureIndexMap if not ready
|
||||
if (!featureIndexMapReady) {
|
||||
cachedFeatureIndexMap.clear();
|
||||
for (size_t i = 0; i < features.size(); ++i) {
|
||||
cachedFeatureIndexMap[features[i]] = i;
|
||||
}
|
||||
featureIndexMapReady = true;
|
||||
|
||||
// Build feature index map
|
||||
std::unordered_map<std::string, int> featureIndexMap;
|
||||
for (size_t i = 0; i < features.size(); ++i) {
|
||||
featureIndexMap[features[i]] = i;
|
||||
}
|
||||
const auto& featureIndexMap = cachedFeatureIndexMap;
|
||||
|
||||
// Gather indices for node and parents
|
||||
std::vector<int64_t> all_indices;
|
||||
all_indices.push_back(featureIndexMap.at(name));
|
||||
all_indices.push_back(featureIndexMap[name]);
|
||||
for (const auto& parent : parents) {
|
||||
all_indices.push_back(featureIndexMap.at(parent->getName()));
|
||||
all_indices.push_back(featureIndexMap[parent->getName()]);
|
||||
}
|
||||
|
||||
// Extract relevant columns: shape (num_features, num_samples)
|
||||
auto indices_tensor = dataset.index_select(0, torch::tensor(all_indices, torch::kLong));
|
||||
// Transpose to (num_samples, num_features)
|
||||
indices_tensor = indices_tensor.transpose(0, 1).to(torch::kLong);
|
||||
// Flatten CPT for easier indexing
|
||||
auto flat_cpt = cpTable.flatten();
|
||||
// Compute strides for flattening multi-dim indices
|
||||
indices_tensor = indices_tensor.transpose(0, 1).to(torch::kLong); // (num_samples, num_features)
|
||||
|
||||
// Manual flattening of indices
|
||||
std::vector<int64_t> strides(all_indices.size(), 1);
|
||||
for (int i = strides.size() - 2; i >= 0; --i) {
|
||||
strides[i] = strides[i + 1] * cpTable.size(i + 1);
|
||||
}
|
||||
// Compute flat indices for each sample
|
||||
auto indices_tensor_cpu = indices_tensor.cpu();
|
||||
auto indices_accessor = indices_tensor_cpu.accessor<int64_t, 2>();
|
||||
std::vector<int64_t> flat_indices(indices_tensor.size(0));
|
||||
@@ -141,13 +133,15 @@ namespace bayesnet {
|
||||
}
|
||||
flat_indices[i] = idx;
|
||||
}
|
||||
|
||||
// Accumulate weights into flat CPT
|
||||
auto flat_cpt = cpTable.flatten();
|
||||
auto flat_indices_tensor = torch::from_blob(flat_indices.data(), { (int64_t)flat_indices.size() }, torch::kLong).clone();
|
||||
flat_cpt.index_add_(0, flat_indices_tensor, weights.cpu());
|
||||
cpTable = flat_cpt.view(cpTable.sizes());
|
||||
|
||||
// Normalize the counts (dividing each row by the sum of the row)
|
||||
cpTable /= cpTable.sum(0, true);
|
||||
return;
|
||||
}
|
||||
double Node::getFactorValue(std::map<std::string, int>& evidence)
|
||||
{
|
||||
|
Reference in New Issue
Block a user