Refactor BaseClassifier and begin TAN impl.

This commit is contained in:
2023-07-14 00:10:55 +02:00
parent e52fdc718f
commit 3f09d474f9
9 changed files with 87 additions and 49 deletions

View File

@@ -3,23 +3,26 @@
#include <torch/torch.h>
#include <vector>
#include <string>
using namespace std;
namespace bayesnet {
using namespace std;
using namespace torch;
class Metrics {
private:
torch::Tensor samples;
Tensor samples;
vector<string> features;
string className;
int classNumStates;
vector<pair<string, string>> doCombinations(const vector<string>&);
double entropy(torch::Tensor&);
double conditionalEntropy(torch::Tensor&, torch::Tensor&);
public:
double mutualInformation(torch::Tensor&, torch::Tensor&);
Metrics(torch::Tensor&, vector<string>&, string&, int);
Metrics() = default;
Metrics(Tensor&, vector<string>&, string&, int);
Metrics(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&, const int);
double entropy(Tensor&);
double conditionalEntropy(Tensor&, Tensor&);
double mutualInformation(Tensor&, Tensor&);
vector<float> conditionalEdgeWeights();
torch::Tensor conditionalEdge();
Tensor conditionalEdge();
vector<pair<string, string>> doCombinations(const vector<string>&);
vector<pair<int, int>> maximumSpanningTree(Tensor& weights);
};
}
#endif