2023-07-11 20:23:49 +00:00
|
|
|
#ifndef BAYESNET_METRICS_H
|
|
|
|
#define BAYESNET_METRICS_H
|
|
|
|
#include <torch/torch.h>
|
|
|
|
#include <vector>
|
|
|
|
#include <string>
|
|
|
|
using namespace std;
|
|
|
|
namespace bayesnet {
|
|
|
|
class Metrics {
|
|
|
|
private:
|
2023-07-11 23:05:24 +00:00
|
|
|
torch::Tensor samples;
|
2023-07-12 01:23:28 +00:00
|
|
|
vector<string> features;
|
|
|
|
string className;
|
2023-07-11 20:23:49 +00:00
|
|
|
int classNumStates;
|
|
|
|
vector<pair<string, string>> doCombinations(const vector<string>&);
|
|
|
|
double entropy(torch::Tensor&);
|
|
|
|
double conditionalEntropy(torch::Tensor&, torch::Tensor&);
|
|
|
|
public:
|
2023-07-13 01:15:42 +00:00
|
|
|
double mutualInformation(torch::Tensor&, torch::Tensor&);
|
2023-07-11 20:23:49 +00:00
|
|
|
Metrics(torch::Tensor&, vector<string>&, string&, int);
|
2023-07-12 01:23:28 +00:00
|
|
|
Metrics(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&, const int);
|
2023-07-11 23:05:24 +00:00
|
|
|
vector<float> conditionalEdgeWeights();
|
2023-07-13 01:44:33 +00:00
|
|
|
torch::Tensor conditionalEdge();
|
2023-07-11 20:23:49 +00:00
|
|
|
};
|
|
|
|
}
|
|
|
|
#endif
|