Refactor BaseClassifier and begin TAN impl.
This commit is contained in:
@@ -3,23 +3,26 @@
|
||||
#include <torch/torch.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
using namespace std;
|
||||
namespace bayesnet {
|
||||
using namespace std;
|
||||
using namespace torch;
|
||||
class Metrics {
|
||||
private:
|
||||
torch::Tensor samples;
|
||||
Tensor samples;
|
||||
vector<string> features;
|
||||
string className;
|
||||
int classNumStates;
|
||||
vector<pair<string, string>> doCombinations(const vector<string>&);
|
||||
double entropy(torch::Tensor&);
|
||||
double conditionalEntropy(torch::Tensor&, torch::Tensor&);
|
||||
public:
|
||||
double mutualInformation(torch::Tensor&, torch::Tensor&);
|
||||
Metrics(torch::Tensor&, vector<string>&, string&, int);
|
||||
Metrics() = default;
|
||||
Metrics(Tensor&, vector<string>&, string&, int);
|
||||
Metrics(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&, const int);
|
||||
double entropy(Tensor&);
|
||||
double conditionalEntropy(Tensor&, Tensor&);
|
||||
double mutualInformation(Tensor&, Tensor&);
|
||||
vector<float> conditionalEdgeWeights();
|
||||
torch::Tensor conditionalEdge();
|
||||
Tensor conditionalEdge();
|
||||
vector<pair<string, string>> doCombinations(const vector<string>&);
|
||||
vector<pair<int, int>> maximumSpanningTree(Tensor& weights);
|
||||
};
|
||||
}
|
||||
#endif
|
Reference in New Issue
Block a user