Add cycle detect adding edges

This commit is contained in:
2023-06-29 23:53:33 +02:00
parent d59bf03a51
commit 31c22898de
5 changed files with 87 additions and 18 deletions

View File

@@ -8,6 +8,8 @@ namespace bayesnet {
private:
map<string, Node*> nodes;
map<string, torch::Tensor> cpds; // Map from CPD key to CPD tensor
Node* root = nullptr;
bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
public:
~Network();
void addNode(string, int);
@@ -16,6 +18,8 @@ namespace bayesnet {
void fit(const vector<vector<int>>&, const int);
torch::Tensor& getCPD(const string&);
void setCPD(const string&, const torch::Tensor&);
void setRoot(string);
Node* getRoot();
};
}
#endif