Add entropy, conditionalEntropy, mutualInformation and conditionalEdgeWeight methods

This commit is contained in:
2023-07-11 17:42:20 +02:00
parent 3750662f2c
commit c7e2042c6e
8 changed files with 683 additions and 43 deletions

View File

@@ -19,7 +19,12 @@ namespace bayesnet {
vector<double> predict_sample(const vector<int>&);
vector<double> exactInference(map<string, int>&);
double computeFactor(map<string, int>&);
double mutual_info(torch::Tensor&, torch::Tensor&);
double entropy(torch::Tensor&);
double conditionalEntropy(torch::Tensor&, torch::Tensor&);
double mutualInformation(torch::Tensor&, torch::Tensor&);
public:
torch::Tensor samples;
Network();
Network(float, int);
Network(float);
@@ -35,6 +40,8 @@ namespace bayesnet {
string getClassName();
void fit(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&);
vector<int> predict(const vector<vector<int>>&);
//Computes the conditional edge weight of variable index u and v conditioned on class_node
torch::Tensor conditionalEdgeWeight();
vector<vector<double>> predict_proba(const vector<vector<int>>&);
double score(const vector<vector<int>>&, const vector<int>&);
inline string version() { return "0.1.0"; }