Begin Network build

This commit is contained in:
2023-06-29 22:00:41 +02:00
parent 00ed1e0be1
commit 8e9c3483aa
18 changed files with 874 additions and 0 deletions

21
Network.h Normal file
View File

@@ -0,0 +1,21 @@
#ifndef NETWORK_H
#define NETWORK_H
#include "Node.h"
#include <map>
#include <vector>
namespace bayesnet {
class Network {
private:
map<string, Node*> nodes;
map<string, torch::Tensor> cpds; // Map from CPD key to CPD tensor
public:
~Network();
void addNode(string, int);
void addEdge(const string, const string);
map<string, Node*>& getNodes();
void fit(const vector<vector<int>>&, const int);
torch::Tensor& getCPD(const string&);
void setCPD(const string&, const torch::Tensor&);
};
}
#endif