BayesNet/bayesnet/network/Node.h

42 lines
1.7 KiB
C
Raw Permalink Normal View History

2024-04-11 16:02:49 +00:00
// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************
2023-06-29 20:00:41 +00:00
#ifndef NODE_H
#define NODE_H
#include <cstdint>
#include <map>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include <torch/torch.h>
2023-06-29 20:00:41 +00:00
namespace bayesnet {
class Node {
2023-10-11 19:17:26 +00:00
public:
2023-11-08 17:45:35 +00:00
explicit Node(const std::string&);
2023-07-25 23:39:01 +00:00
void clear();
2023-06-30 19:24:12 +00:00
void addParent(Node*);
2023-06-29 21:53:33 +00:00
void addChild(Node*);
void removeParent(Node*);
void removeChild(Node*);
2023-11-08 17:45:35 +00:00
std::string getName() const;
std::vector<Node*>& getParents();
std::vector<Node*>& getChildren();
2023-06-29 20:00:41 +00:00
torch::Tensor& getCPT();
2023-11-08 17:45:35 +00:00
void computeCPT(const torch::Tensor& dataset, const std::vector<std::string>& features, const double laplaceSmoothing, const torch::Tensor& weights);
2023-06-29 20:00:41 +00:00
int getNumStates() const;
void setNumStates(int);
unsigned minFill();
2023-11-08 17:45:35 +00:00
std::vector<std::string> graph(const std::string& clasName); // Returns a std::vector of std::strings representing the graph in graphviz format
float getFactorValue(std::map<std::string, int>&);
private:
std::string name;
std::vector<Node*> parents;
std::vector<Node*> children;
int numStates = 0; // number of states of the variable
torch::Tensor cpTable; // Order of indices is 0-> node variable, 1-> 1st parent, 2-> 2nd parent, ...
std::vector<int64_t> dimensions; // dimensions of the cpTable
std::vector<std::pair<std::string, std::string>> combinations(const std::vector<std::string>&);
2023-06-29 20:00:41 +00:00
};
}
#endif