Refactor Network and create Metrics class

This commit is contained in:
2023-07-11 22:23:49 +02:00
parent c7e2042c6e
commit d1eaab6408
7 changed files with 137 additions and 190 deletions

View File

@@ -15,6 +15,7 @@ namespace bayesnet {
vector<string> features;
string className;
int laplaceSmoothing;
torch::Tensor samples;
bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
vector<double> predict_sample(const vector<int>&);
vector<double> exactInference(map<string, int>&);
@@ -24,12 +25,12 @@ namespace bayesnet {
double conditionalEntropy(torch::Tensor&, torch::Tensor&);
double mutualInformation(torch::Tensor&, torch::Tensor&);
public:
torch::Tensor samples;
Network();
Network(float, int);
Network(float);
Network(Network&);
~Network();
torch::Tensor& getSamples();
float getmaxThreads();
void addNode(string, int);
void addEdge(const string, const string);