// BayesNet/src/Metrics.hpp

#ifndef BAYESNET_METRICS_H
#define BAYESNET_METRICS_H
#include <torch/torch.h>
#include <vector>
#include <string>
using namespace std;
namespace bayesnet {
class Metrics {
private:
2023-07-11 23:05:24 +00:00
torch::Tensor samples;
2023-07-12 01:23:28 +00:00
vector<string> features;
string className;
int classNumStates;
vector<pair<string, string>> doCombinations(const vector<string>&);
double entropy(torch::Tensor&);
double conditionalEntropy(torch::Tensor&, torch::Tensor&);
public:
2023-07-13 01:15:42 +00:00
double mutualInformation(torch::Tensor&, torch::Tensor&);
Metrics(torch::Tensor&, vector<string>&, string&, int);
2023-07-12 01:23:28 +00:00
Metrics(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&, const int);
2023-07-11 23:05:24 +00:00
vector<float> conditionalEdgeWeights();
2023-07-13 01:44:33 +00:00
torch::Tensor conditionalEdge();
};
}
#endif