Complete estimateParameters
checked with pgmpy
This commit is contained in:
parent
71d730d228
commit
52cb85b41b
146
Network.cc
146
Network.cc
@ -99,135 +99,33 @@ namespace bayesnet {
|
||||
estimateParameters();
|
||||
}
|
||||
|
||||
// void Network::estimateParameters()
|
||||
// {
|
||||
// auto dimensions = vector<int64_t>();
|
||||
// for (auto [name, node] : nodes) {
|
||||
// // Get dimensions of the CPT
|
||||
// dimensions.clear();
|
||||
// dimensions.push_back(node->getNumStates());
|
||||
// for (auto father : node->getParents()) {
|
||||
// dimensions.push_back(father->getNumStates());
|
||||
// }
|
||||
// auto length = dimensions.size();
|
||||
// // Create a tensor of zeros with the dimensions of the CPT
|
||||
// torch::Tensor cpt = torch::zeros(dimensions, torch::kFloat);
|
||||
// // Fill table with counts
|
||||
// for (int n_sample = 0; n_sample < dataset[name].size(); ++n_sample) {
|
||||
// torch::List<c10::optional<torch::Tensor>> coordinates;
|
||||
// coordinates.push_back(torch::tensor(dataset[name][n_sample]));
|
||||
// for (auto father : node->getParents()) {
|
||||
// coordinates.push_back(torch::tensor(dataset[father->getName()][n_sample]));
|
||||
// }
|
||||
// // Increment the count of the corresponding coordinate
|
||||
// cpt.index_put_({ coordinates }, cpt.index({ coordinates }) + 1);
|
||||
// }
|
||||
// // store thre resulting cpt in the node
|
||||
// node->setCPT(cpt);
|
||||
// }
|
||||
// }
|
||||
|
||||
// void Network::estimateParameters()
|
||||
// {
|
||||
// // Lambda function to compute joint counts of states
|
||||
// auto jointCounts = [this](const vector<string>& nodeNames) {
|
||||
// int size = nodeNames.size();
|
||||
// std::vector<int64_t> sizes(size);
|
||||
|
||||
// for (int i = 0; i < size; ++i) {
|
||||
// sizes[i] = this->nodes[nodeNames[i]]->getNumStates();
|
||||
// }
|
||||
|
||||
// torch::Tensor counts = torch::zeros(sizes, torch::kLong);
|
||||
|
||||
// int dataSize = this->dataset[nodeNames[0]].size();
|
||||
|
||||
// for (int dataIdx = 0; dataIdx < dataSize; ++dataIdx) {
|
||||
// std::vector<torch::Tensor> idx(size);
|
||||
// for (int i = 0; i < size; ++i) {
|
||||
// idx[i] = torch::tensor(this->dataset[nodeNames[i]][dataIdx], torch::kLong);
|
||||
// }
|
||||
// torch::Tensor indices = torch::stack(idx);
|
||||
// counts.index_put_({ indices }, counts.index({ indices }) + 1);
|
||||
// }
|
||||
|
||||
// return counts;
|
||||
// };
|
||||
|
||||
// // Lambda function to compute marginal counts of states
|
||||
// auto marginalCounts = [](const torch::Tensor& jointCounts) {
|
||||
// return jointCounts.sum(-1);
|
||||
// };
|
||||
|
||||
// for (auto& pair : nodes) {
|
||||
// Node* node = pair.second;
|
||||
|
||||
// // Create a list of names of the node and its parents
|
||||
// std::vector<string> nodeNames;
|
||||
// nodeNames.push_back(node->getName());
|
||||
// for (Node* parent : node->getParents()) {
|
||||
// nodeNames.push_back(parent->getName());
|
||||
// }
|
||||
|
||||
// // Compute counts and normalize to get probabilities
|
||||
// torch::Tensor counts = jointCounts(nodeNames) + laplaceSmoothing;
|
||||
// torch::Tensor parentCounts = marginalCounts(counts);
|
||||
// parentCounts = parentCounts.unsqueeze(-1);
|
||||
|
||||
// // The CPT is represented as a tensor and stored in the Node
|
||||
// node->setCPT((counts.to(torch::kDouble) / parentCounts.to(torch::kDouble)));
|
||||
// }
|
||||
// }
|
||||
void Network::estimateParameters()
|
||||
{
|
||||
// Lambda function to compute joint counts of states
|
||||
auto jointCounts = [this](const vector<string>& nodeNames) {
|
||||
int size = nodeNames.size();
|
||||
std::vector<int64_t> sizes(size);
|
||||
|
||||
for (int i = 0; i < size; ++i) {
|
||||
sizes[i] = this->nodes[nodeNames[i]]->getNumStates();
|
||||
auto dimensions = vector<int64_t>();
|
||||
for (auto [name, node] : nodes) {
|
||||
// Get dimensions of the CPT
|
||||
dimensions.clear();
|
||||
dimensions.push_back(node->getNumStates());
|
||||
for (auto father : node->getParents()) {
|
||||
dimensions.push_back(father->getNumStates());
|
||||
}
|
||||
|
||||
torch::Tensor counts = torch::zeros(sizes, torch::kLong);
|
||||
|
||||
int dataSize = this->dataset[nodeNames[0]].size();
|
||||
torch::List<c10::optional<torch::Tensor>> indices;
|
||||
for (int dataIdx = 0; dataIdx < dataSize; ++dataIdx) {
|
||||
indices.clear();
|
||||
for (int i = 0; i < size; ++i) {
|
||||
indices.push_back(torch::tensor(this->dataset[nodeNames[i]][dataIdx], torch::kLong));
|
||||
auto length = dimensions.size();
|
||||
// Create a tensor of zeros with the dimensions of the CPT
|
||||
torch::Tensor cpt = torch::zeros(dimensions, torch::kFloat) + laplaceSmoothing;
|
||||
// Fill table with counts
|
||||
for (int n_sample = 0; n_sample < dataset[name].size(); ++n_sample) {
|
||||
torch::List<c10::optional<torch::Tensor>> coordinates;
|
||||
coordinates.push_back(torch::tensor(dataset[name][n_sample]));
|
||||
for (auto father : node->getParents()) {
|
||||
coordinates.push_back(torch::tensor(dataset[father->getName()][n_sample]));
|
||||
}
|
||||
//torch::Tensor indices = torch::stack(idx);
|
||||
counts.index_put_({ indices }, counts.index({ indices }) + 1);
|
||||
// Increment the count of the corresponding coordinate
|
||||
cpt.index_put_({ coordinates }, cpt.index({ coordinates }) + 1);
|
||||
}
|
||||
|
||||
return counts;
|
||||
};
|
||||
|
||||
// Lambda function to compute marginal counts of states
|
||||
auto marginalCounts = [](const torch::Tensor& jointCounts) {
|
||||
return jointCounts.sum(-1);
|
||||
};
|
||||
|
||||
for (auto& pair : nodes) {
|
||||
Node* node = pair.second;
|
||||
|
||||
// Create a list of names of the node and its parents
|
||||
std::vector<string> nodeNames;
|
||||
nodeNames.push_back(node->getName());
|
||||
for (Node* parent : node->getParents()) {
|
||||
nodeNames.push_back(parent->getName());
|
||||
}
|
||||
|
||||
// Compute counts and normalize to get probabilities
|
||||
torch::Tensor counts = jointCounts(nodeNames) + laplaceSmoothing;
|
||||
torch::Tensor parentCounts = marginalCounts(counts);
|
||||
parentCounts = parentCounts.unsqueeze(-1);
|
||||
|
||||
// The CPT is represented as a tensor and stored in the Node
|
||||
node->setCPT((counts.to(torch::kDouble) / parentCounts.to(torch::kDouble)));
|
||||
// Normalize the counts
|
||||
cpt = cpt / cpt.sum(0);
|
||||
// store thre resulting cpt in the node
|
||||
node->setCPT(cpt);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
47
test.cc
47
test.cc
@ -24,30 +24,37 @@
|
||||
|
||||
int main()
|
||||
{
|
||||
torch::Tensor t = torch::rand({ 5, 4, 3 }); // 3D tensor for this example
|
||||
int i = 3, j = 1, k = 2; // Indices for the cell you want to update
|
||||
//torch::Tensor t = torch::rand({ 5, 4, 3 }); // 3D tensor for this example
|
||||
//int i = 3, j = 1, k = 2; // Indices for the cell you want to update
|
||||
// Print original tensor
|
||||
torch::Tensor t = torch::tensor({ {1, 2, 3}, {4, 5, 6} }); // 3D tensor for this example
|
||||
std::cout << t << std::endl;
|
||||
std::cout << "sum(0)" << std::endl;
|
||||
std::cout << t.sum(0) << std::endl;
|
||||
std::cout << "sum(1)" << std::endl;
|
||||
std::cout << t.sum(1) << std::endl;
|
||||
std::cout << "Normalized" << std::endl;
|
||||
std::cout << t / t.sum(0) << std::endl;
|
||||
|
||||
// New value
|
||||
torch::Tensor new_val = torch::tensor(10.0f);
|
||||
// torch::Tensor new_val = torch::tensor(10.0f);
|
||||
|
||||
// Indices for the cell you want to update
|
||||
std::vector<torch::Tensor> indices;
|
||||
indices.push_back(torch::tensor(i)); // Replace i with your index for the 1st dimension
|
||||
indices.push_back(torch::tensor(j)); // Replace j with your index for the 2nd dimension
|
||||
indices.push_back(torch::tensor(k)); // Replace k with your index for the 3rd dimension
|
||||
//torch::ArrayRef<at::indexing::TensorIndex> indices_ref(indices);
|
||||
// Update cell
|
||||
//torch::Tensor result = torch::stack(indices);
|
||||
//torch::List<c10::optional<torch::Tensor>> indices_list = { torch::tensor(i), torch::tensor(j), torch::tensor(k) };
|
||||
torch::List<c10::optional<torch::Tensor>> indices_list;
|
||||
indices_list.push_back(torch::tensor(i));
|
||||
indices_list.push_back(torch::tensor(j));
|
||||
indices_list.push_back(torch::tensor(k));
|
||||
//t.index_put_({ torch::tensor(i), torch::tensor(j), torch::tensor(k) }, new_val);
|
||||
t.index_put_(indices_list, new_val);
|
||||
// // Indices for the cell you want to update
|
||||
// std::vector<torch::Tensor> indices;
|
||||
// indices.push_back(torch::tensor(i)); // Replace i with your index for the 1st dimension
|
||||
// indices.push_back(torch::tensor(j)); // Replace j with your index for the 2nd dimension
|
||||
// indices.push_back(torch::tensor(k)); // Replace k with your index for the 3rd dimension
|
||||
// //torch::ArrayRef<at::indexing::TensorIndex> indices_ref(indices);
|
||||
// // Update cell
|
||||
// //torch::Tensor result = torch::stack(indices);
|
||||
// //torch::List<c10::optional<torch::Tensor>> indices_list = { torch::tensor(i), torch::tensor(j), torch::tensor(k) };
|
||||
// torch::List<c10::optional<torch::Tensor>> indices_list;
|
||||
// indices_list.push_back(torch::tensor(i));
|
||||
// indices_list.push_back(torch::tensor(j));
|
||||
// indices_list.push_back(torch::tensor(k));
|
||||
// //t.index_put_({ torch::tensor(i), torch::tensor(j), torch::tensor(k) }, new_val);
|
||||
// t.index_put_(indices_list, new_val);
|
||||
|
||||
// Print updated tensor
|
||||
std::cout << t << std::endl;
|
||||
// // Print updated tensor
|
||||
// std::cout << t << std::endl;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user