Make fit build the network

This commit is contained in:
2023-06-30 02:46:06 +02:00
parent 31c22898de
commit 0a31aa2ff1
13 changed files with 580 additions and 82 deletions

View File

@@ -8,14 +8,18 @@ namespace bayesnet {
private:
map<string, Node*> nodes;
map<string, torch::Tensor> cpds; // Map from CPD key to CPD tensor
Node* root = nullptr;
Node* root;
int laplaceSmoothing;
bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
public:
Network();
Network(int);
~Network();
void addNode(string, int);
void addEdge(const string, const string);
map<string, Node*>& getNodes();
void fit(const vector<vector<int>>&, const int);
void fit(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&);
void buildNetwork(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&);
torch::Tensor& getCPD(const string&);
void setCPD(const string&, const torch::Tensor&);
void setRoot(string);