Begin implementing KDB

This commit is contained in:
2023-07-13 03:15:42 +02:00
parent c5386d66fc
commit 8b0aa5ccfb
8 changed files with 201 additions and 3 deletions

39
src/BaseClassifier.h Normal file
View File

@@ -0,0 +1,39 @@
#ifndef CLASSIFIERS_H
#include <torch/torch.h>
#include "Network.h"
using namespace std;
using namespace torch;
namespace bayesnet {
class BaseClassifier {
private:
BaseClassifier& build(vector<string>& features, string className, map<string, vector<int>>& states);
protected:
Network model;
int m, n; // m: number of samples, n: number of features
Tensor X;
Tensor y;
Tensor dataset;
vector<string> features;
string className;
map<string, vector<int>> states;
void checkFitParameters();
virtual void train() = 0;
public:
BaseClassifier(Network model);
Tensor& getX();
vector<string>& getFeatures();
string& getClassName();
BaseClassifier& fit(Tensor& X, Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states);
BaseClassifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states);
Tensor predict(Tensor& X);
float score(Tensor& X, Tensor& y);
};
}
#endif