Complete estimateParameters
checked with pgmpy
This commit is contained in:
parent
71d730d228
commit
52cb85b41b
146
Network.cc
146
Network.cc
@ -99,135 +99,33 @@ namespace bayesnet {
|
|||||||
estimateParameters();
|
estimateParameters();
|
||||||
}
|
}
|
||||||
|
|
||||||
// void Network::estimateParameters()
|
|
||||||
// {
|
|
||||||
// auto dimensions = vector<int64_t>();
|
|
||||||
// for (auto [name, node] : nodes) {
|
|
||||||
// // Get dimensions of the CPT
|
|
||||||
// dimensions.clear();
|
|
||||||
// dimensions.push_back(node->getNumStates());
|
|
||||||
// for (auto father : node->getParents()) {
|
|
||||||
// dimensions.push_back(father->getNumStates());
|
|
||||||
// }
|
|
||||||
// auto length = dimensions.size();
|
|
||||||
// // Create a tensor of zeros with the dimensions of the CPT
|
|
||||||
// torch::Tensor cpt = torch::zeros(dimensions, torch::kFloat);
|
|
||||||
// // Fill table with counts
|
|
||||||
// for (int n_sample = 0; n_sample < dataset[name].size(); ++n_sample) {
|
|
||||||
// torch::List<c10::optional<torch::Tensor>> coordinates;
|
|
||||||
// coordinates.push_back(torch::tensor(dataset[name][n_sample]));
|
|
||||||
// for (auto father : node->getParents()) {
|
|
||||||
// coordinates.push_back(torch::tensor(dataset[father->getName()][n_sample]));
|
|
||||||
// }
|
|
||||||
// // Increment the count of the corresponding coordinate
|
|
||||||
// cpt.index_put_({ coordinates }, cpt.index({ coordinates }) + 1);
|
|
||||||
// }
|
|
||||||
// // store thre resulting cpt in the node
|
|
||||||
// node->setCPT(cpt);
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// void Network::estimateParameters()
|
|
||||||
// {
|
|
||||||
// // Lambda function to compute joint counts of states
|
|
||||||
// auto jointCounts = [this](const vector<string>& nodeNames) {
|
|
||||||
// int size = nodeNames.size();
|
|
||||||
// std::vector<int64_t> sizes(size);
|
|
||||||
|
|
||||||
// for (int i = 0; i < size; ++i) {
|
|
||||||
// sizes[i] = this->nodes[nodeNames[i]]->getNumStates();
|
|
||||||
// }
|
|
||||||
|
|
||||||
// torch::Tensor counts = torch::zeros(sizes, torch::kLong);
|
|
||||||
|
|
||||||
// int dataSize = this->dataset[nodeNames[0]].size();
|
|
||||||
|
|
||||||
// for (int dataIdx = 0; dataIdx < dataSize; ++dataIdx) {
|
|
||||||
// std::vector<torch::Tensor> idx(size);
|
|
||||||
// for (int i = 0; i < size; ++i) {
|
|
||||||
// idx[i] = torch::tensor(this->dataset[nodeNames[i]][dataIdx], torch::kLong);
|
|
||||||
// }
|
|
||||||
// torch::Tensor indices = torch::stack(idx);
|
|
||||||
// counts.index_put_({ indices }, counts.index({ indices }) + 1);
|
|
||||||
// }
|
|
||||||
|
|
||||||
// return counts;
|
|
||||||
// };
|
|
||||||
|
|
||||||
// // Lambda function to compute marginal counts of states
|
|
||||||
// auto marginalCounts = [](const torch::Tensor& jointCounts) {
|
|
||||||
// return jointCounts.sum(-1);
|
|
||||||
// };
|
|
||||||
|
|
||||||
// for (auto& pair : nodes) {
|
|
||||||
// Node* node = pair.second;
|
|
||||||
|
|
||||||
// // Create a list of names of the node and its parents
|
|
||||||
// std::vector<string> nodeNames;
|
|
||||||
// nodeNames.push_back(node->getName());
|
|
||||||
// for (Node* parent : node->getParents()) {
|
|
||||||
// nodeNames.push_back(parent->getName());
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // Compute counts and normalize to get probabilities
|
|
||||||
// torch::Tensor counts = jointCounts(nodeNames) + laplaceSmoothing;
|
|
||||||
// torch::Tensor parentCounts = marginalCounts(counts);
|
|
||||||
// parentCounts = parentCounts.unsqueeze(-1);
|
|
||||||
|
|
||||||
// // The CPT is represented as a tensor and stored in the Node
|
|
||||||
// node->setCPT((counts.to(torch::kDouble) / parentCounts.to(torch::kDouble)));
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
void Network::estimateParameters()
|
void Network::estimateParameters()
|
||||||
{
|
{
|
||||||
// Lambda function to compute joint counts of states
|
auto dimensions = vector<int64_t>();
|
||||||
auto jointCounts = [this](const vector<string>& nodeNames) {
|
for (auto [name, node] : nodes) {
|
||||||
int size = nodeNames.size();
|
// Get dimensions of the CPT
|
||||||
std::vector<int64_t> sizes(size);
|
dimensions.clear();
|
||||||
|
dimensions.push_back(node->getNumStates());
|
||||||
for (int i = 0; i < size; ++i) {
|
for (auto father : node->getParents()) {
|
||||||
sizes[i] = this->nodes[nodeNames[i]]->getNumStates();
|
dimensions.push_back(father->getNumStates());
|
||||||
}
|
}
|
||||||
|
auto length = dimensions.size();
|
||||||
torch::Tensor counts = torch::zeros(sizes, torch::kLong);
|
// Create a tensor of zeros with the dimensions of the CPT
|
||||||
|
torch::Tensor cpt = torch::zeros(dimensions, torch::kFloat) + laplaceSmoothing;
|
||||||
int dataSize = this->dataset[nodeNames[0]].size();
|
// Fill table with counts
|
||||||
torch::List<c10::optional<torch::Tensor>> indices;
|
for (int n_sample = 0; n_sample < dataset[name].size(); ++n_sample) {
|
||||||
for (int dataIdx = 0; dataIdx < dataSize; ++dataIdx) {
|
torch::List<c10::optional<torch::Tensor>> coordinates;
|
||||||
indices.clear();
|
coordinates.push_back(torch::tensor(dataset[name][n_sample]));
|
||||||
for (int i = 0; i < size; ++i) {
|
for (auto father : node->getParents()) {
|
||||||
indices.push_back(torch::tensor(this->dataset[nodeNames[i]][dataIdx], torch::kLong));
|
coordinates.push_back(torch::tensor(dataset[father->getName()][n_sample]));
|
||||||
}
|
}
|
||||||
//torch::Tensor indices = torch::stack(idx);
|
// Increment the count of the corresponding coordinate
|
||||||
counts.index_put_({ indices }, counts.index({ indices }) + 1);
|
cpt.index_put_({ coordinates }, cpt.index({ coordinates }) + 1);
|
||||||
}
|
}
|
||||||
|
// Normalize the counts
|
||||||
return counts;
|
cpt = cpt / cpt.sum(0);
|
||||||
};
|
// store thre resulting cpt in the node
|
||||||
|
node->setCPT(cpt);
|
||||||
// Lambda function to compute marginal counts of states
|
|
||||||
auto marginalCounts = [](const torch::Tensor& jointCounts) {
|
|
||||||
return jointCounts.sum(-1);
|
|
||||||
};
|
|
||||||
|
|
||||||
for (auto& pair : nodes) {
|
|
||||||
Node* node = pair.second;
|
|
||||||
|
|
||||||
// Create a list of names of the node and its parents
|
|
||||||
std::vector<string> nodeNames;
|
|
||||||
nodeNames.push_back(node->getName());
|
|
||||||
for (Node* parent : node->getParents()) {
|
|
||||||
nodeNames.push_back(parent->getName());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute counts and normalize to get probabilities
|
|
||||||
torch::Tensor counts = jointCounts(nodeNames) + laplaceSmoothing;
|
|
||||||
torch::Tensor parentCounts = marginalCounts(counts);
|
|
||||||
parentCounts = parentCounts.unsqueeze(-1);
|
|
||||||
|
|
||||||
// The CPT is represented as a tensor and stored in the Node
|
|
||||||
node->setCPT((counts.to(torch::kDouble) / parentCounts.to(torch::kDouble)));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
47
test.cc
47
test.cc
@ -24,30 +24,37 @@
|
|||||||
|
|
||||||
int main()
|
int main()
|
||||||
{
|
{
|
||||||
torch::Tensor t = torch::rand({ 5, 4, 3 }); // 3D tensor for this example
|
//torch::Tensor t = torch::rand({ 5, 4, 3 }); // 3D tensor for this example
|
||||||
int i = 3, j = 1, k = 2; // Indices for the cell you want to update
|
//int i = 3, j = 1, k = 2; // Indices for the cell you want to update
|
||||||
// Print original tensor
|
// Print original tensor
|
||||||
|
torch::Tensor t = torch::tensor({ {1, 2, 3}, {4, 5, 6} }); // 3D tensor for this example
|
||||||
std::cout << t << std::endl;
|
std::cout << t << std::endl;
|
||||||
|
std::cout << "sum(0)" << std::endl;
|
||||||
|
std::cout << t.sum(0) << std::endl;
|
||||||
|
std::cout << "sum(1)" << std::endl;
|
||||||
|
std::cout << t.sum(1) << std::endl;
|
||||||
|
std::cout << "Normalized" << std::endl;
|
||||||
|
std::cout << t / t.sum(0) << std::endl;
|
||||||
|
|
||||||
// New value
|
// New value
|
||||||
torch::Tensor new_val = torch::tensor(10.0f);
|
// torch::Tensor new_val = torch::tensor(10.0f);
|
||||||
|
|
||||||
// Indices for the cell you want to update
|
// // Indices for the cell you want to update
|
||||||
std::vector<torch::Tensor> indices;
|
// std::vector<torch::Tensor> indices;
|
||||||
indices.push_back(torch::tensor(i)); // Replace i with your index for the 1st dimension
|
// indices.push_back(torch::tensor(i)); // Replace i with your index for the 1st dimension
|
||||||
indices.push_back(torch::tensor(j)); // Replace j with your index for the 2nd dimension
|
// indices.push_back(torch::tensor(j)); // Replace j with your index for the 2nd dimension
|
||||||
indices.push_back(torch::tensor(k)); // Replace k with your index for the 3rd dimension
|
// indices.push_back(torch::tensor(k)); // Replace k with your index for the 3rd dimension
|
||||||
//torch::ArrayRef<at::indexing::TensorIndex> indices_ref(indices);
|
// //torch::ArrayRef<at::indexing::TensorIndex> indices_ref(indices);
|
||||||
// Update cell
|
// // Update cell
|
||||||
//torch::Tensor result = torch::stack(indices);
|
// //torch::Tensor result = torch::stack(indices);
|
||||||
//torch::List<c10::optional<torch::Tensor>> indices_list = { torch::tensor(i), torch::tensor(j), torch::tensor(k) };
|
// //torch::List<c10::optional<torch::Tensor>> indices_list = { torch::tensor(i), torch::tensor(j), torch::tensor(k) };
|
||||||
torch::List<c10::optional<torch::Tensor>> indices_list;
|
// torch::List<c10::optional<torch::Tensor>> indices_list;
|
||||||
indices_list.push_back(torch::tensor(i));
|
// indices_list.push_back(torch::tensor(i));
|
||||||
indices_list.push_back(torch::tensor(j));
|
// indices_list.push_back(torch::tensor(j));
|
||||||
indices_list.push_back(torch::tensor(k));
|
// indices_list.push_back(torch::tensor(k));
|
||||||
//t.index_put_({ torch::tensor(i), torch::tensor(j), torch::tensor(k) }, new_val);
|
// //t.index_put_({ torch::tensor(i), torch::tensor(j), torch::tensor(k) }, new_val);
|
||||||
t.index_put_(indices_list, new_val);
|
// t.index_put_(indices_list, new_val);
|
||||||
|
|
||||||
// Print updated tensor
|
// // Print updated tensor
|
||||||
std::cout << t << std::endl;
|
// std::cout << t << std::endl;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user