2024-04-11 16:02:49 +00:00
|
|
|
// ***************************************************************
|
|
|
|
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
|
|
|
|
// SPDX-FileType: SOURCE
|
|
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
// ***************************************************************
|
|
|
|
|
2023-06-29 20:00:41 +00:00
|
|
|
#ifndef NODE_H
|
|
|
|
#define NODE_H
|
2023-07-11 15:42:20 +00:00
|
|
|
#include <cstdint>
#include <map>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include <torch/torch.h>
|
2023-06-29 20:00:41 +00:00
|
|
|
namespace bayesnet {
|
|
|
|
class Node {
|
|
|
|
private:
|
2023-11-08 17:45:35 +00:00
|
|
|
std::string name;
|
|
|
|
std::vector<Node*> parents;
|
|
|
|
std::vector<Node*> children;
|
2023-07-05 16:38:54 +00:00
|
|
|
int numStates; // number of states of the variable
|
|
|
|
torch::Tensor cpTable; // Order of indices is 0-> node variable, 1-> 1st parent, 2-> 2nd parent, ...
|
2023-11-08 17:45:35 +00:00
|
|
|
std::vector<int64_t> dimensions; // dimensions of the cpTable
|
|
|
|
std::vector<std::pair<std::string, std::string>> combinations(const std::vector<std::string>&);
|
2023-10-11 19:17:26 +00:00
|
|
|
public:
|
2023-11-08 17:45:35 +00:00
|
|
|
explicit Node(const std::string&);
|
2023-07-25 23:39:01 +00:00
|
|
|
void clear();
|
2023-06-30 19:24:12 +00:00
|
|
|
void addParent(Node*);
|
2023-06-29 21:53:33 +00:00
|
|
|
void addChild(Node*);
|
|
|
|
void removeParent(Node*);
|
|
|
|
void removeChild(Node*);
|
2023-11-08 17:45:35 +00:00
|
|
|
std::string getName() const;
|
|
|
|
std::vector<Node*>& getParents();
|
|
|
|
std::vector<Node*>& getChildren();
|
2023-06-29 20:00:41 +00:00
|
|
|
torch::Tensor& getCPT();
|
2023-11-08 17:45:35 +00:00
|
|
|
void computeCPT(const torch::Tensor& dataset, const std::vector<std::string>& features, const double laplaceSmoothing, const torch::Tensor& weights);
|
2023-06-29 20:00:41 +00:00
|
|
|
int getNumStates() const;
|
2023-07-01 12:45:44 +00:00
|
|
|
void setNumStates(int);
|
2023-07-02 18:39:13 +00:00
|
|
|
unsigned minFill();
|
2023-11-08 17:45:35 +00:00
|
|
|
std::vector<std::string> graph(const std::string& clasName); // Returns a std::vector of std::strings representing the graph in graphviz format
|
|
|
|
float getFactorValue(std::map<std::string, int>&);
|
2023-06-29 20:00:41 +00:00
|
|
|
};
|
|
|
|
}
|
|
|
|
#endif
|