2023-06-29 20:00:41 +00:00
|
|
|
#ifndef NODE_H
|
|
|
|
#define NODE_H
|
|
|
|
#include <torch/torch.h>
|
|
|
|
#include <vector>
|
|
|
|
#include <string>
|
|
|
|
namespace bayesnet {
|
|
|
|
using namespace std;
|
|
|
|
class Node {
|
|
|
|
private:
|
|
|
|
static int next_id;
|
|
|
|
const int id;
|
|
|
|
string name;
|
|
|
|
vector<Node*> parents;
|
|
|
|
vector<Node*> children;
|
|
|
|
int numStates;
|
|
|
|
torch::Tensor cpt;
|
|
|
|
public:
|
2023-06-29 21:53:33 +00:00
|
|
|
Node(const std::string&, int);
|
|
|
|
void addParent(Node*);
|
|
|
|
void addChild(Node*);
|
|
|
|
void removeParent(Node*);
|
|
|
|
void removeChild(Node*);
|
2023-06-29 20:00:41 +00:00
|
|
|
string getName() const;
|
|
|
|
vector<Node*>& getParents();
|
|
|
|
vector<Node*>& getChildren();
|
|
|
|
torch::Tensor& getCPT();
|
2023-06-29 21:53:33 +00:00
|
|
|
void setCPT(const torch::Tensor&);
|
2023-06-29 20:00:41 +00:00
|
|
|
int getNumStates() const;
|
|
|
|
int getId() const { return id; }
|
|
|
|
string getCPDKey(const Node*) const;
|
|
|
|
};
|
|
|
|
}
|
|
|
|
#endif
|