Refactor BaseClassifier and begin TAN impl.

This commit is contained in:
2023-07-14 00:10:55 +02:00
parent e52fdc718f
commit 3f09d474f9
9 changed files with 87 additions and 49 deletions

View File

@@ -1,6 +1,7 @@
#ifndef CLASSIFIERS_H
#include <torch/torch.h>
#include "Network.h"
#include "Metrics.hpp"
using namespace std;
using namespace torch;
@@ -14,6 +15,7 @@ namespace bayesnet {
Tensor X;
Tensor y;
Tensor dataset;
Metrics metrics;
vector<string> features;
string className;
map<string, vector<int>> states;
@@ -21,14 +23,13 @@ namespace bayesnet {
virtual void train() = 0;
public:
BaseClassifier(Network model);
Tensor& getX();
vector<string>& getFeatures();
string& getClassName();
BaseClassifier& fit(Tensor& X, Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states);
BaseClassifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states);
void addNodes();
Tensor predict(Tensor& X);
float score(Tensor& X, Tensor& y);
void show();
vector<string> show();
vector<int> argsort(vector<float>& nums);
};
}
#endif