diff --git a/Network.cc b/Network.cc index 8e650f4..1c18fbb 100644 --- a/Network.cc +++ b/Network.cc @@ -99,135 +99,33 @@ namespace bayesnet { estimateParameters(); } - // void Network::estimateParameters() - // { - // auto dimensions = vector(); - // for (auto [name, node] : nodes) { - // // Get dimensions of the CPT - // dimensions.clear(); - // dimensions.push_back(node->getNumStates()); - // for (auto father : node->getParents()) { - // dimensions.push_back(father->getNumStates()); - // } - // auto length = dimensions.size(); - // // Create a tensor of zeros with the dimensions of the CPT - // torch::Tensor cpt = torch::zeros(dimensions, torch::kFloat); - // // Fill table with counts - // for (int n_sample = 0; n_sample < dataset[name].size(); ++n_sample) { - // torch::List> coordinates; - // coordinates.push_back(torch::tensor(dataset[name][n_sample])); - // for (auto father : node->getParents()) { - // coordinates.push_back(torch::tensor(dataset[father->getName()][n_sample])); - // } - // // Increment the count of the corresponding coordinate - // cpt.index_put_({ coordinates }, cpt.index({ coordinates }) + 1); - // } - // // store thre resulting cpt in the node - // node->setCPT(cpt); - // } - // } - - // void Network::estimateParameters() - // { - // // Lambda function to compute joint counts of states - // auto jointCounts = [this](const vector& nodeNames) { - // int size = nodeNames.size(); - // std::vector sizes(size); - - // for (int i = 0; i < size; ++i) { - // sizes[i] = this->nodes[nodeNames[i]]->getNumStates(); - // } - - // torch::Tensor counts = torch::zeros(sizes, torch::kLong); - - // int dataSize = this->dataset[nodeNames[0]].size(); - - // for (int dataIdx = 0; dataIdx < dataSize; ++dataIdx) { - // std::vector idx(size); - // for (int i = 0; i < size; ++i) { - // idx[i] = torch::tensor(this->dataset[nodeNames[i]][dataIdx], torch::kLong); - // } - // torch::Tensor indices = torch::stack(idx); - // counts.index_put_({ indices }, counts.index({ indices }) + 1); - // } - - // return counts; - // }; - - // // Lambda function to compute marginal counts of states - // auto marginalCounts = [](const torch::Tensor& jointCounts) { - // return jointCounts.sum(-1); - // }; - - // for (auto& pair : nodes) { - // Node* node = pair.second; - - // // Create a list of names of the node and its parents - // std::vector nodeNames; - // nodeNames.push_back(node->getName()); - // for (Node* parent : node->getParents()) { - // nodeNames.push_back(parent->getName()); - // } - - // // Compute counts and normalize to get probabilities - // torch::Tensor counts = jointCounts(nodeNames) + laplaceSmoothing; - // torch::Tensor parentCounts = marginalCounts(counts); - // parentCounts = parentCounts.unsqueeze(-1); - - // // The CPT is represented as a tensor and stored in the Node - // node->setCPT((counts.to(torch::kDouble) / parentCounts.to(torch::kDouble))); - // } - // } void Network::estimateParameters() { - // Lambda function to compute joint counts of states - auto jointCounts = [this](const vector& nodeNames) { - int size = nodeNames.size(); - std::vector sizes(size); - - for (int i = 0; i < size; ++i) { - sizes[i] = this->nodes[nodeNames[i]]->getNumStates(); + auto dimensions = vector(); + for (auto [name, node] : nodes) { + // Get dimensions of the CPT + dimensions.clear(); + dimensions.push_back(node->getNumStates()); + for (auto father : node->getParents()) { + dimensions.push_back(father->getNumStates()); } - - torch::Tensor counts = torch::zeros(sizes, torch::kLong); - - int dataSize = this->dataset[nodeNames[0]].size(); - torch::List> indices; - for (int dataIdx = 0; dataIdx < dataSize; ++dataIdx) { - indices.clear(); - for (int i = 0; i < size; ++i) { - indices.push_back(torch::tensor(this->dataset[nodeNames[i]][dataIdx], torch::kLong)); + auto length = dimensions.size(); + // Create a tensor of zeros with the dimensions of the CPT + torch::Tensor cpt = torch::zeros(dimensions, torch::kFloat) + laplaceSmoothing; + // Fill table with counts + for (int n_sample = 0; n_sample < dataset[name].size(); ++n_sample) { + torch::List> coordinates; + coordinates.push_back(torch::tensor(dataset[name][n_sample])); + for (auto father : node->getParents()) { + coordinates.push_back(torch::tensor(dataset[father->getName()][n_sample])); } - //torch::Tensor indices = torch::stack(idx); - counts.index_put_({ indices }, counts.index({ indices }) + 1); + // Increment the count of the corresponding coordinate + cpt.index_put_({ coordinates }, cpt.index({ coordinates }) + 1); } - - return counts; - }; - - // Lambda function to compute marginal counts of states - auto marginalCounts = [](const torch::Tensor& jointCounts) { - return jointCounts.sum(-1); - }; - - for (auto& pair : nodes) { - Node* node = pair.second; - - // Create a list of names of the node and its parents - std::vector nodeNames; - nodeNames.push_back(node->getName()); - for (Node* parent : node->getParents()) { - nodeNames.push_back(parent->getName()); - } - - // Compute counts and normalize to get probabilities - torch::Tensor counts = jointCounts(nodeNames) + laplaceSmoothing; - torch::Tensor parentCounts = marginalCounts(counts); - parentCounts = parentCounts.unsqueeze(-1); - - // The CPT is represented as a tensor and stored in the Node - node->setCPT((counts.to(torch::kDouble) / parentCounts.to(torch::kDouble))); + // Normalize the counts + cpt = cpt / cpt.sum(0); + // store thre resulting cpt in the node + node->setCPT(cpt); } } - } diff --git a/test.cc b/test.cc index e1166a3..4784189 100644 --- a/test.cc +++ b/test.cc @@ -24,30 +24,37 @@ int main() { - torch::Tensor t = torch::rand({ 5, 4, 3 }); // 3D tensor for this example - int i = 3, j = 1, k = 2; // Indices for the cell you want to update + //torch::Tensor t = torch::rand({ 5, 4, 3 }); // 3D tensor for this example + //int i = 3, j = 1, k = 2; // Indices for the cell you want to update // Print original tensor + torch::Tensor t = torch::tensor({ {1, 2, 3}, {4, 5, 6} }); // 3D tensor for this example std::cout << t << std::endl; + std::cout << "sum(0)" << std::endl; + std::cout << t.sum(0) << std::endl; + std::cout << "sum(1)" << std::endl; + std::cout << t.sum(1) << std::endl; + std::cout << "Normalized" << std::endl; + std::cout << t / t.sum(0) << std::endl; // New value - torch::Tensor new_val = torch::tensor(10.0f); + // torch::Tensor new_val = torch::tensor(10.0f); - // Indices for the cell you want to update - std::vector indices; - indices.push_back(torch::tensor(i)); // Replace i with your index for the 1st dimension - indices.push_back(torch::tensor(j)); // Replace j with your index for the 2nd dimension - indices.push_back(torch::tensor(k)); // Replace k with your index for the 3rd dimension - //torch::ArrayRef indices_ref(indices); - // Update cell - //torch::Tensor result = torch::stack(indices); - //torch::List> indices_list = { torch::tensor(i), torch::tensor(j), torch::tensor(k) }; - torch::List> indices_list; - indices_list.push_back(torch::tensor(i)); - indices_list.push_back(torch::tensor(j)); - indices_list.push_back(torch::tensor(k)); - //t.index_put_({ torch::tensor(i), torch::tensor(j), torch::tensor(k) }, new_val); - t.index_put_(indices_list, new_val); + // // Indices for the cell you want to update + // std::vector indices; + // indices.push_back(torch::tensor(i)); // Replace i with your index for the 1st dimension + // indices.push_back(torch::tensor(j)); // Replace j with your index for the 2nd dimension + // indices.push_back(torch::tensor(k)); // Replace k with your index for the 3rd dimension + // //torch::ArrayRef indices_ref(indices); + // // Update cell + // //torch::Tensor result = torch::stack(indices); + // //torch::List> indices_list = { torch::tensor(i), torch::tensor(j), torch::tensor(k) }; + // torch::List> indices_list; + // indices_list.push_back(torch::tensor(i)); + // indices_list.push_back(torch::tensor(j)); + // indices_list.push_back(torch::tensor(k)); + // //t.index_put_({ torch::tensor(i), torch::tensor(j), torch::tensor(k) }, new_val); + // t.index_put_(indices_list, new_val); - // Print updated tensor - std::cout << t << std::endl; + // // Print updated tensor + // std::cout << t << std::endl; }