// BayesNet/bayesnet/classifiers/Proposal.h

// 37 lines
// 1.4 KiB
// C
// Raw Normal View History

// 2024-04-11 16:02:49 +00:00
// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************
// 2023-08-04 11:05:12 +00:00
#ifndef PROPOSAL_H
#define PROPOSAL_H
#include <string>
#include <map>
#include <torch/torch.h>
// 2024-03-08 00:13:30 +00:00
#include <CPPFImdlp.h>
// 2024-03-08 21:20:54 +00:00
#include "bayesnet/network/Network.h"
#include "Classifier.h"
// 2023-08-04 11:05:12 +00:00
namespace bayesnet {
class Proposal {
public:
2023-11-08 17:45:35 +00:00
Proposal(torch::Tensor& pDataset, std::vector<std::string>& features_, std::string& className_);
2023-08-04 17:42:18 +00:00
virtual ~Proposal();
2023-08-04 11:05:12 +00:00
protected:
2023-08-24 10:09:35 +00:00
void checkInput(const torch::Tensor& X, const torch::Tensor& y);
torch::Tensor prepareX(torch::Tensor& X);
2023-11-08 17:45:35 +00:00
map<std::string, std::vector<int>> localDiscretizationProposal(const map<std::string, std::vector<int>>& states, Network& model);
map<std::string, std::vector<int>> fit_local_discretization(const torch::Tensor& y);
2023-08-04 11:05:12 +00:00
torch::Tensor Xf; // X continuous nxm tensor
torch::Tensor y; // y discrete nx1 tensor
2023-11-08 17:45:35 +00:00
map<std::string, mdlp::CPPFImdlp*> discretizers;
2023-08-04 11:05:12 +00:00
private:
std::vector<int> factorize(const std::vector<std::string>& labels_t);
torch::Tensor& pDataset; // (n+1)xm tensor
2023-11-08 17:45:35 +00:00
std::vector<std::string>& pFeatures;
std::string& pClassName;
2023-08-04 11:05:12 +00:00
};
}
#endif