Make fit build the network

This commit is contained in:
2023-06-30 02:46:06 +02:00
parent 31c22898de
commit 0a31aa2ff1
13 changed files with 580 additions and 82 deletions

View File

@@ -1,5 +1,7 @@
#include "Network.h"
namespace bayesnet {
Network::Network() : laplaceSmoothing(1), root(nullptr) {}
Network::Network(int smoothing) : laplaceSmoothing(smoothing), root(nullptr) {}
Network::~Network()
{
for (auto& pair : nodes) {
@@ -70,46 +72,62 @@ namespace bayesnet {
{
return nodes;
}
void Network::fit(const vector<vector<int>>& dataset, const int smoothing)
void Network::buildNetwork(const vector<vector<int>>& dataset, const vector<int>& labels, const vector<string>& featureNames, const string& className)
{
auto jointCounts = [](const vector<vector<int>>& data, const vector<int>& indices, int numStates) {
int size = indices.size();
vector<int64_t> sizes(size, numStates);
torch::Tensor counts = torch::zeros(sizes, torch::kLong);
for (const auto& row : data) {
int idx = 0;
for (int i = 0; i < size; ++i) {
idx = idx * numStates + row[indices[i]];
}
counts.view({ -1 }).add_(idx, 1);
}
return counts;
};
auto marginalCounts = [](const torch::Tensor& jointCounts) {
return jointCounts.sum(-1);
};
for (auto& pair : nodes) {
Node* node = pair.second;
vector<int> indices;
for (const auto& parent : node->getParents()) {
indices.push_back(nodes[parent->getName()]->getId());
}
indices.push_back(node->getId());
for (auto& child : node->getChildren()) {
torch::Tensor counts = jointCounts(dataset, indices, node->getNumStates()) + smoothing;
torch::Tensor parentCounts = marginalCounts(counts);
parentCounts = parentCounts.unsqueeze(-1);
torch::Tensor cpt = counts.to(torch::kDouble) / parentCounts.to(torch::kDouble);
setCPD(node->getCPDKey(child), cpt);
}
// Add features as nodes to the network
for (int i = 0; i < featureNames.size(); ++i) {
addNode(featureNames[i], *max_element(dataset[i].begin(), dataset[i].end()) + 1);
}
// Add class as node to the network
addNode(className, *max_element(labels.begin(), labels.end()) + 1);
// Add edges from class to features => naive Bayes
for (auto feature : featureNames) {
addEdge(className, feature);
}
}
void Network::fit(const vector<vector<int>>& dataset, const vector<int>& labels, const vector<string>& featureNames, const string& className)
{
buildNetwork(dataset, labels, featureNames, className);
//estimateParameters(dataset);
// auto jointCounts = [](const vector<vector<int>>& data, const vector<int>& indices, int numStates) {
// int size = indices.size();
// vector<int64_t> sizes(size, numStates);
// torch::Tensor counts = torch::zeros(sizes, torch::kLong);
// for (const auto& row : data) {
// int idx = 0;
// for (int i = 0; i < size; ++i) {
// idx = idx * numStates + row[indices[i]];
// }
// counts.view({ -1 }).add_(idx, 1);
// }
// return counts;
// };
// auto marginalCounts = [](const torch::Tensor& jointCounts) {
// return jointCounts.sum(-1);
// };
// for (auto& pair : nodes) {
// Node* node = pair.second;
// vector<int> indices;
// for (const auto& parent : node->getParents()) {
// indices.push_back(nodes[parent->getName()]->getId());
// }
// indices.push_back(node->getId());
// for (auto& child : node->getChildren()) {
// torch::Tensor counts = jointCounts(dataset, indices, node->getNumStates()) + laplaceSmoothing;
// torch::Tensor parentCounts = marginalCounts(counts);
// parentCounts = parentCounts.unsqueeze(-1);
// torch::Tensor cpt = counts.to(torch::kDouble) / parentCounts.to(torch::kDouble);
// setCPD(node->getCPDKey(child), cpt);
// }
// }
}
torch::Tensor& Network::getCPD(const string& key)