BayesNet/src/BayesNet/BaseClassifier.h

23 lines
1.0 KiB
C
Raw Normal View History

2023-07-22 21:07:56 +00:00
#ifndef BASE_H
#define BASE_H
2023-07-13 01:15:42 +00:00
#include <map>
#include <string>
#include <vector>
#include <torch/torch.h>
2023-07-13 01:15:42 +00:00
namespace bayesnet {
2023-07-22 21:07:56 +00:00
using namespace std;
2023-07-13 01:15:42 +00:00
class BaseClassifier {
public:
2023-07-22 21:07:56 +00:00
virtual BaseClassifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states) = 0;
virtual BaseClassifier& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) = 0;
2023-07-22 21:07:56 +00:00
vector<int> virtual predict(vector<vector<int>>& X) = 0;
float virtual score(vector<vector<int>>& X, vector<int>& y) = 0;
float virtual score(torch::Tensor& X, torch::Tensor& y) = 0;
int virtual getNumberOfNodes() = 0;
int virtual getNumberOfEdges() = 0;
int virtual getNumberOfStates() = 0;
2023-07-22 21:07:56 +00:00
vector<string> virtual show() = 0;
vector<string> virtual graph(string title = "") = 0;
virtual ~BaseClassifier() = default;
2023-07-26 23:56:06 +00:00
const string inline getVersion() const { return "0.1.0"; };
2023-07-13 01:15:42 +00:00
};
}
2023-07-22 21:07:56 +00:00
#endif