// *************************************************************** // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez // SPDX-FileType: SOURCE // SPDX-License-Identifier: MIT // *************************************************************** #ifndef PROPOSAL_H #define PROPOSAL_H #include #include #include #include #include "bayesnet/network/Network.h" #include "Classifier.h" namespace bayesnet { class Proposal { public: Proposal(torch::Tensor& pDataset, std::vector& features_, std::string& className_); virtual ~Proposal(); protected: void checkInput(const torch::Tensor& X, const torch::Tensor& y); torch::Tensor prepareX(torch::Tensor& X); map> localDiscretizationProposal(const map>& states, Network& model); map> fit_local_discretization(const torch::Tensor& y); torch::Tensor Xf; // X continuous nxm tensor torch::Tensor y; // y discrete nx1 tensor map discretizers; private: std::vector factorize(const std::vector& labels_t); torch::Tensor& pDataset; // (n+1)xm tensor std::vector& pFeatures; std::string& pClassName; }; } #endif