Begin with parameter estimation
This commit is contained in:
11
Network.h
11
Network.h
@@ -3,12 +3,16 @@
|
||||
#include "Node.h"
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
|
||||
namespace bayesnet {
|
||||
class Network {
|
||||
private:
|
||||
map<string, Node*> nodes;
|
||||
map<string, torch::Tensor> cpds; // Map from CPD key to CPD tensor
|
||||
map<string, vector<int>> dataset;
|
||||
Node* root;
|
||||
vector<string> features;
|
||||
string className;
|
||||
int laplaceSmoothing;
|
||||
bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
|
||||
public:
|
||||
@@ -19,9 +23,8 @@ namespace bayesnet {
|
||||
void addEdge(const string, const string);
|
||||
map<string, Node*>& getNodes();
|
||||
void fit(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&);
|
||||
void buildNetwork(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&);
|
||||
torch::Tensor& getCPD(const string&);
|
||||
void setCPD(const string&, const torch::Tensor&);
|
||||
void estimateParameters();
|
||||
void buildNetwork();
|
||||
void setRoot(string);
|
||||
Node* getRoot();
|
||||
};
|
||||
|
Reference in New Issue
Block a user