Refactor Library 2 include a Platform/ Experiments

This commit is contained in:
2023-07-20 10:40:08 +02:00
parent 2f5bd0ea7e
commit 5f70449091
42 changed files with 269 additions and 42 deletions

View File

@@ -1,48 +0,0 @@
#ifndef CLASSIFIERS_H
#define CLASSIFIERS_H
#include <torch/torch.h>
#include "Network.h"
#include "Metrics.hpp"
using namespace std;
using namespace torch;
namespace bayesnet {
class BaseClassifier {
private:
bool fitted;
BaseClassifier& build(vector<string>& features, string className, map<string, vector<int>>& states);
protected:
Network model;
int m, n; // m: number of samples, n: number of features
Tensor X;
vector<vector<int>> Xv;
Tensor y;
vector<int> yv;
Tensor dataset;
Metrics metrics;
vector<string> features;
string className;
map<string, vector<int>> states;
void checkFitParameters();
virtual void train() = 0;
public:
BaseClassifier(Network model);
virtual ~BaseClassifier() = default;
BaseClassifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states);
void addNodes();
int getNumberOfNodes();
int getNumberOfEdges();
Tensor predict(Tensor& X);
vector<int> predict(vector<vector<int>>& X);
float score(Tensor& X, Tensor& y);
float score(vector<vector<int>>& X, vector<int>& y);
vector<string> show();
virtual vector<string> graph(string title) = 0;
};
}
#endif