#ifndef NETWORK_H #define NETWORK_H #include "Node.h" #include #include namespace bayesnet { class Network { private: map nodes; map cpds; // Map from CPD key to CPD tensor Node* root = nullptr; bool isCyclic(const std::string&, std::unordered_set&, std::unordered_set&); public: ~Network(); void addNode(string, int); void addEdge(const string, const string); map& getNodes(); void fit(const vector>&, const int); torch::Tensor& getCPD(const string&); void setCPD(const string&, const torch::Tensor&); void setRoot(string); Node* getRoot(); }; } #endif