2023-07-22 21:07:56 +00:00
|
|
|
#ifndef BASE_H
|
|
|
|
#define BASE_H
|
2023-07-13 01:15:42 +00:00
|
|
|
#include <torch/torch.h>
|
2023-07-22 21:07:56 +00:00
|
|
|
#include <vector>
|
2023-07-13 01:15:42 +00:00
|
|
|
namespace bayesnet {
|
2023-07-22 21:07:56 +00:00
|
|
|
using namespace std;
|
2023-07-13 01:15:42 +00:00
|
|
|
class BaseClassifier {
|
|
|
|
public:
|
2023-07-22 21:07:56 +00:00
|
|
|
virtual BaseClassifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states) = 0;
|
2023-07-23 12:10:28 +00:00
|
|
|
virtual BaseClassifier& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) = 0;
|
2023-07-30 17:00:02 +00:00
|
|
|
torch::Tensor virtual predict(torch::Tensor& X) = 0;
|
2023-07-22 21:07:56 +00:00
|
|
|
vector<int> virtual predict(vector<vector<int>>& X) = 0;
|
|
|
|
float virtual score(vector<vector<int>>& X, vector<int>& y) = 0;
|
2023-07-23 12:10:28 +00:00
|
|
|
float virtual score(torch::Tensor& X, torch::Tensor& y) = 0;
|
2023-07-26 17:01:39 +00:00
|
|
|
int virtual getNumberOfNodes() = 0;
|
|
|
|
int virtual getNumberOfEdges() = 0;
|
|
|
|
int virtual getNumberOfStates() = 0;
|
2023-07-22 21:07:56 +00:00
|
|
|
vector<string> virtual show() = 0;
|
2023-07-31 17:53:55 +00:00
|
|
|
vector<string> virtual graph(const string& title = "") = 0;
|
2023-07-14 23:05:36 +00:00
|
|
|
virtual ~BaseClassifier() = default;
|
2023-07-26 23:56:06 +00:00
|
|
|
const string inline getVersion() const { return "0.1.0"; };
|
2023-08-01 22:56:52 +00:00
|
|
|
vector<string> virtual topological_order() = 0;
|
2023-07-13 01:15:42 +00:00
|
|
|
};
|
|
|
|
}
|
2023-07-22 21:07:56 +00:00
|
|
|
#endif
|