Refactor New classifiers to extract predict

This commit is contained in:
2023-08-05 18:39:48 +02:00
parent 1a09ccca4c
commit 7f45495837
6 changed files with 35 additions and 50 deletions

View File

@@ -5,6 +5,7 @@
#include <torch/torch.h>
#include "Network.h"
#include "CPPFImdlp.h"
#include "Classifier.h"
namespace bayesnet {
class Proposal {
@@ -12,6 +13,7 @@ namespace bayesnet {
Proposal(vector<vector<int>>& Xv_, vector<int>& yv_, vector<string>& features_, string& className_);
virtual ~Proposal();
protected:
torch::Tensor prepareX(torch::Tensor& X);
void localDiscretizationProposal(map<string, vector<int>>& states, Network& model);
void fit_local_discretization(map<string, vector<int>>& states, torch::Tensor& y);
torch::Tensor Xf; // X continuous nxm tensor