BayesNet/src/BayesNet/BaseClassifier.h

25 lines
1.1 KiB
C
Raw Normal View History

2023-07-22 21:07:56 +00:00
#ifndef BASE_H
#define BASE_H
2023-07-13 01:15:42 +00:00
#include <map>
#include <string>
#include <vector>

#include <torch/torch.h>
2023-07-13 01:15:42 +00:00
namespace bayesnet {
2023-07-22 21:07:56 +00:00
using namespace std;
2023-07-13 01:15:42 +00:00
class BaseClassifier {
public:
2023-07-22 21:07:56 +00:00
virtual BaseClassifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states) = 0;
virtual BaseClassifier& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) = 0;
2023-07-30 17:00:02 +00:00
torch::Tensor virtual predict(torch::Tensor& X) = 0;
2023-07-22 21:07:56 +00:00
vector<int> virtual predict(vector<vector<int>>& X) = 0;
float virtual score(vector<vector<int>>& X, vector<int>& y) = 0;
float virtual score(torch::Tensor& X, torch::Tensor& y) = 0;
int virtual getNumberOfNodes() = 0;
int virtual getNumberOfEdges() = 0;
int virtual getNumberOfStates() = 0;
2023-07-22 21:07:56 +00:00
vector<string> virtual show() = 0;
2023-07-31 17:53:55 +00:00
vector<string> virtual graph(const string& title = "") = 0;
virtual ~BaseClassifier() = default;
2023-07-26 23:56:06 +00:00
const string inline getVersion() const { return "0.1.0"; };
2023-08-01 22:56:52 +00:00
vector<string> virtual topological_order() = 0;
2023-07-13 01:15:42 +00:00
};
}
2023-07-22 21:07:56 +00:00
#endif