Make fit build the network
This commit is contained in:
94
Network.cc
94
Network.cc
@@ -1,5 +1,7 @@
|
||||
#include "Network.h"
|
||||
namespace bayesnet {
|
||||
Network::Network() : laplaceSmoothing(1), root(nullptr) {}
|
||||
Network::Network(int smoothing) : laplaceSmoothing(smoothing), root(nullptr) {}
|
||||
Network::~Network()
|
||||
{
|
||||
for (auto& pair : nodes) {
|
||||
@@ -70,46 +72,62 @@ namespace bayesnet {
|
||||
{
|
||||
return nodes;
|
||||
}
|
||||
void Network::fit(const vector<vector<int>>& dataset, const int smoothing)
|
||||
void Network::buildNetwork(const vector<vector<int>>& dataset, const vector<int>& labels, const vector<string>& featureNames, const string& className)
|
||||
{
|
||||
auto jointCounts = [](const vector<vector<int>>& data, const vector<int>& indices, int numStates) {
|
||||
int size = indices.size();
|
||||
vector<int64_t> sizes(size, numStates);
|
||||
torch::Tensor counts = torch::zeros(sizes, torch::kLong);
|
||||
|
||||
for (const auto& row : data) {
|
||||
int idx = 0;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
idx = idx * numStates + row[indices[i]];
|
||||
}
|
||||
counts.view({ -1 }).add_(idx, 1);
|
||||
}
|
||||
|
||||
return counts;
|
||||
};
|
||||
|
||||
auto marginalCounts = [](const torch::Tensor& jointCounts) {
|
||||
return jointCounts.sum(-1);
|
||||
};
|
||||
|
||||
for (auto& pair : nodes) {
|
||||
Node* node = pair.second;
|
||||
|
||||
vector<int> indices;
|
||||
for (const auto& parent : node->getParents()) {
|
||||
indices.push_back(nodes[parent->getName()]->getId());
|
||||
}
|
||||
indices.push_back(node->getId());
|
||||
|
||||
for (auto& child : node->getChildren()) {
|
||||
torch::Tensor counts = jointCounts(dataset, indices, node->getNumStates()) + smoothing;
|
||||
torch::Tensor parentCounts = marginalCounts(counts);
|
||||
parentCounts = parentCounts.unsqueeze(-1);
|
||||
|
||||
torch::Tensor cpt = counts.to(torch::kDouble) / parentCounts.to(torch::kDouble);
|
||||
setCPD(node->getCPDKey(child), cpt);
|
||||
}
|
||||
// Add features as nodes to the network
|
||||
for (int i = 0; i < featureNames.size(); ++i) {
|
||||
addNode(featureNames[i], *max_element(dataset[i].begin(), dataset[i].end()) + 1);
|
||||
}
|
||||
// Add class as node to the network
|
||||
addNode(className, *max_element(labels.begin(), labels.end()) + 1);
|
||||
// Add edges from class to features => naive Bayes
|
||||
for (auto feature : featureNames) {
|
||||
addEdge(className, feature);
|
||||
}
|
||||
}
|
||||
void Network::fit(const vector<vector<int>>& dataset, const vector<int>& labels, const vector<string>& featureNames, const string& className)
|
||||
{
|
||||
buildNetwork(dataset, labels, featureNames, className);
|
||||
//estimateParameters(dataset);
|
||||
|
||||
// auto jointCounts = [](const vector<vector<int>>& data, const vector<int>& indices, int numStates) {
|
||||
// int size = indices.size();
|
||||
// vector<int64_t> sizes(size, numStates);
|
||||
// torch::Tensor counts = torch::zeros(sizes, torch::kLong);
|
||||
|
||||
// for (const auto& row : data) {
|
||||
// int idx = 0;
|
||||
// for (int i = 0; i < size; ++i) {
|
||||
// idx = idx * numStates + row[indices[i]];
|
||||
// }
|
||||
// counts.view({ -1 }).add_(idx, 1);
|
||||
// }
|
||||
|
||||
// return counts;
|
||||
// };
|
||||
|
||||
// auto marginalCounts = [](const torch::Tensor& jointCounts) {
|
||||
// return jointCounts.sum(-1);
|
||||
// };
|
||||
|
||||
// for (auto& pair : nodes) {
|
||||
// Node* node = pair.second;
|
||||
|
||||
// vector<int> indices;
|
||||
// for (const auto& parent : node->getParents()) {
|
||||
// indices.push_back(nodes[parent->getName()]->getId());
|
||||
// }
|
||||
// indices.push_back(node->getId());
|
||||
|
||||
// for (auto& child : node->getChildren()) {
|
||||
// torch::Tensor counts = jointCounts(dataset, indices, node->getNumStates()) + laplaceSmoothing;
|
||||
// torch::Tensor parentCounts = marginalCounts(counts);
|
||||
// parentCounts = parentCounts.unsqueeze(-1);
|
||||
|
||||
// torch::Tensor cpt = counts.to(torch::kDouble) / parentCounts.to(torch::kDouble);
|
||||
// setCPD(node->getCPDKey(child), cpt);
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
torch::Tensor& Network::getCPD(const string& key)
|
||||
|
Reference in New Issue
Block a user