Complete estimateParameters

checked with pgmpy
This commit is contained in:
Ricardo Montañana Gómez 2023-07-01 01:44:56 +02:00
parent 71d730d228
commit 52cb85b41b
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
2 changed files with 49 additions and 144 deletions

View File

@ -99,135 +99,33 @@ namespace bayesnet {
estimateParameters();
}
// void Network::estimateParameters()
// {
// auto dimensions = vector<int64_t>();
// for (auto [name, node] : nodes) {
// // Get dimensions of the CPT
// dimensions.clear();
// dimensions.push_back(node->getNumStates());
// for (auto father : node->getParents()) {
// dimensions.push_back(father->getNumStates());
// }
// auto length = dimensions.size();
// // Create a tensor of zeros with the dimensions of the CPT
// torch::Tensor cpt = torch::zeros(dimensions, torch::kFloat);
// // Fill table with counts
// for (int n_sample = 0; n_sample < dataset[name].size(); ++n_sample) {
// torch::List<c10::optional<torch::Tensor>> coordinates;
// coordinates.push_back(torch::tensor(dataset[name][n_sample]));
// for (auto father : node->getParents()) {
// coordinates.push_back(torch::tensor(dataset[father->getName()][n_sample]));
// }
// // Increment the count of the corresponding coordinate
// cpt.index_put_({ coordinates }, cpt.index({ coordinates }) + 1);
// }
// // store thre resulting cpt in the node
// node->setCPT(cpt);
// }
// }
// void Network::estimateParameters()
// {
// // Lambda function to compute joint counts of states
// auto jointCounts = [this](const vector<string>& nodeNames) {
// int size = nodeNames.size();
// std::vector<int64_t> sizes(size);
// for (int i = 0; i < size; ++i) {
// sizes[i] = this->nodes[nodeNames[i]]->getNumStates();
// }
// torch::Tensor counts = torch::zeros(sizes, torch::kLong);
// int dataSize = this->dataset[nodeNames[0]].size();
// for (int dataIdx = 0; dataIdx < dataSize; ++dataIdx) {
// std::vector<torch::Tensor> idx(size);
// for (int i = 0; i < size; ++i) {
// idx[i] = torch::tensor(this->dataset[nodeNames[i]][dataIdx], torch::kLong);
// }
// torch::Tensor indices = torch::stack(idx);
// counts.index_put_({ indices }, counts.index({ indices }) + 1);
// }
// return counts;
// };
// // Lambda function to compute marginal counts of states
// auto marginalCounts = [](const torch::Tensor& jointCounts) {
// return jointCounts.sum(-1);
// };
// for (auto& pair : nodes) {
// Node* node = pair.second;
// // Create a list of names of the node and its parents
// std::vector<string> nodeNames;
// nodeNames.push_back(node->getName());
// for (Node* parent : node->getParents()) {
// nodeNames.push_back(parent->getName());
// }
// // Compute counts and normalize to get probabilities
// torch::Tensor counts = jointCounts(nodeNames) + laplaceSmoothing;
// torch::Tensor parentCounts = marginalCounts(counts);
// parentCounts = parentCounts.unsqueeze(-1);
// // The CPT is represented as a tensor and stored in the Node
// node->setCPT((counts.to(torch::kDouble) / parentCounts.to(torch::kDouble)));
// }
// }
void Network::estimateParameters()
{
// Lambda function to compute joint counts of states
auto jointCounts = [this](const vector<string>& nodeNames) {
int size = nodeNames.size();
std::vector<int64_t> sizes(size);
for (int i = 0; i < size; ++i) {
sizes[i] = this->nodes[nodeNames[i]]->getNumStates();
auto dimensions = vector<int64_t>();
for (auto [name, node] : nodes) {
// Get dimensions of the CPT
dimensions.clear();
dimensions.push_back(node->getNumStates());
for (auto father : node->getParents()) {
dimensions.push_back(father->getNumStates());
}
torch::Tensor counts = torch::zeros(sizes, torch::kLong);
int dataSize = this->dataset[nodeNames[0]].size();
torch::List<c10::optional<torch::Tensor>> indices;
for (int dataIdx = 0; dataIdx < dataSize; ++dataIdx) {
indices.clear();
for (int i = 0; i < size; ++i) {
indices.push_back(torch::tensor(this->dataset[nodeNames[i]][dataIdx], torch::kLong));
auto length = dimensions.size();
// Create a tensor of zeros with the dimensions of the CPT
torch::Tensor cpt = torch::zeros(dimensions, torch::kFloat) + laplaceSmoothing;
// Fill table with counts
for (int n_sample = 0; n_sample < dataset[name].size(); ++n_sample) {
torch::List<c10::optional<torch::Tensor>> coordinates;
coordinates.push_back(torch::tensor(dataset[name][n_sample]));
for (auto father : node->getParents()) {
coordinates.push_back(torch::tensor(dataset[father->getName()][n_sample]));
}
//torch::Tensor indices = torch::stack(idx);
counts.index_put_({ indices }, counts.index({ indices }) + 1);
// Increment the count of the corresponding coordinate
cpt.index_put_({ coordinates }, cpt.index({ coordinates }) + 1);
}
return counts;
};
// Lambda function to compute marginal counts of states
auto marginalCounts = [](const torch::Tensor& jointCounts) {
return jointCounts.sum(-1);
};
for (auto& pair : nodes) {
Node* node = pair.second;
// Create a list of names of the node and its parents
std::vector<string> nodeNames;
nodeNames.push_back(node->getName());
for (Node* parent : node->getParents()) {
nodeNames.push_back(parent->getName());
}
// Compute counts and normalize to get probabilities
torch::Tensor counts = jointCounts(nodeNames) + laplaceSmoothing;
torch::Tensor parentCounts = marginalCounts(counts);
parentCounts = parentCounts.unsqueeze(-1);
// The CPT is represented as a tensor and stored in the Node
node->setCPT((counts.to(torch::kDouble) / parentCounts.to(torch::kDouble)));
// Normalize the counts
cpt = cpt / cpt.sum(0);
// store thre resulting cpt in the node
node->setCPT(cpt);
}
}
}

47
test.cc
View File

@ -24,30 +24,37 @@
int main()
{
torch::Tensor t = torch::rand({ 5, 4, 3 }); // 3D tensor for this example
int i = 3, j = 1, k = 2; // Indices for the cell you want to update
//torch::Tensor t = torch::rand({ 5, 4, 3 }); // 3D tensor for this example
//int i = 3, j = 1, k = 2; // Indices for the cell you want to update
// Print original tensor
torch::Tensor t = torch::tensor({ {1, 2, 3}, {4, 5, 6} }); // 3D tensor for this example
std::cout << t << std::endl;
std::cout << "sum(0)" << std::endl;
std::cout << t.sum(0) << std::endl;
std::cout << "sum(1)" << std::endl;
std::cout << t.sum(1) << std::endl;
std::cout << "Normalized" << std::endl;
std::cout << t / t.sum(0) << std::endl;
// New value
torch::Tensor new_val = torch::tensor(10.0f);
// torch::Tensor new_val = torch::tensor(10.0f);
// Indices for the cell you want to update
std::vector<torch::Tensor> indices;
indices.push_back(torch::tensor(i)); // Replace i with your index for the 1st dimension
indices.push_back(torch::tensor(j)); // Replace j with your index for the 2nd dimension
indices.push_back(torch::tensor(k)); // Replace k with your index for the 3rd dimension
//torch::ArrayRef<at::indexing::TensorIndex> indices_ref(indices);
// Update cell
//torch::Tensor result = torch::stack(indices);
//torch::List<c10::optional<torch::Tensor>> indices_list = { torch::tensor(i), torch::tensor(j), torch::tensor(k) };
torch::List<c10::optional<torch::Tensor>> indices_list;
indices_list.push_back(torch::tensor(i));
indices_list.push_back(torch::tensor(j));
indices_list.push_back(torch::tensor(k));
//t.index_put_({ torch::tensor(i), torch::tensor(j), torch::tensor(k) }, new_val);
t.index_put_(indices_list, new_val);
// // Indices for the cell you want to update
// std::vector<torch::Tensor> indices;
// indices.push_back(torch::tensor(i)); // Replace i with your index for the 1st dimension
// indices.push_back(torch::tensor(j)); // Replace j with your index for the 2nd dimension
// indices.push_back(torch::tensor(k)); // Replace k with your index for the 3rd dimension
// //torch::ArrayRef<at::indexing::TensorIndex> indices_ref(indices);
// // Update cell
// //torch::Tensor result = torch::stack(indices);
// //torch::List<c10::optional<torch::Tensor>> indices_list = { torch::tensor(i), torch::tensor(j), torch::tensor(k) };
// torch::List<c10::optional<torch::Tensor>> indices_list;
// indices_list.push_back(torch::tensor(i));
// indices_list.push_back(torch::tensor(j));
// indices_list.push_back(torch::tensor(k));
// //t.index_put_({ torch::tensor(i), torch::tensor(j), torch::tensor(k) }, new_val);
// t.index_put_(indices_list, new_val);
// Print updated tensor
std::cout << t << std::endl;
// // Print updated tensor
// std::cout << t << std::endl;
}