// BayesNet/Node.h
// Repository view metadata: 33 lines, 856 B, C; revision timestamp 2023-06-29 20:00:41 +00:00
#ifndef NODE_H
#define NODE_H
#include <torch/torch.h>
#include <vector>
#include <string>
namespace bayesnet {
using namespace std;
class Node {
private:
static int next_id;
const int id;
string name;
vector<Node*> parents;
vector<Node*> children;
int numStates;
torch::Tensor cpt;
public:
2023-06-29 21:53:33 +00:00
Node(const std::string&, int);
void addParent(Node*);
void addChild(Node*);
void removeParent(Node*);
void removeChild(Node*);
2023-06-29 20:00:41 +00:00
string getName() const;
vector<Node*>& getParents();
vector<Node*>& getChildren();
torch::Tensor& getCPT();
2023-06-29 21:53:33 +00:00
void setCPT(const torch::Tensor&);
2023-06-29 20:00:41 +00:00
int getNumStates() const;
int getId() const { return id; }
string getCPDKey(const Node*) const;
};
}
#endif