From 182b52a887c36b7bec335bfe2da9cfe945a59ffa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Sat, 12 Aug 2023 16:16:17 +0200 Subject: [PATCH] Add states as result in Proposal methods --- src/BayesNet/KDBLd.cc | 2 +- src/BayesNet/Proposal.cc | 6 ++++-- src/BayesNet/Proposal.h | 4 ++-- src/BayesNet/SPODELd.cc | 4 ++-- src/BayesNet/TANLd.cc | 2 +- 5 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/BayesNet/KDBLd.cc b/src/BayesNet/KDBLd.cc index 724a053..c236a1c 100644 --- a/src/BayesNet/KDBLd.cc +++ b/src/BayesNet/KDBLd.cc @@ -15,7 +15,7 @@ namespace bayesnet { // We have discretized the input data // 1st we need to fit the model to build the normal KDB structure, KDB::fit initializes the base Bayesian network KDB::fit(dataset, features, className, states); - localDiscretizationProposal(states, model); + states = localDiscretizationProposal(states, model); return *this; } Tensor KDBLd::predict(Tensor& X) diff --git a/src/BayesNet/Proposal.cc b/src/BayesNet/Proposal.cc index d53094d..eef0088 100644 --- a/src/BayesNet/Proposal.cc +++ b/src/BayesNet/Proposal.cc @@ -9,12 +9,13 @@ namespace bayesnet { delete value; } } - void Proposal::localDiscretizationProposal(map>& states, Network& model) + map> Proposal::localDiscretizationProposal(const map>& oldStates, Network& model) { // order of local discretization is important. no good 0, 1, 2... // although we rediscretize features after the local discretization of every feature auto order = model.topological_sort(); auto& nodes = model.getNodes(); + map> states = oldStates; vector indicesToReDiscretize; bool upgrade = false; // Flag to check if we need to upgrade the model for (auto feature : order) { @@ -66,8 +67,9 @@ namespace bayesnet { } model.fit(pDataset, pFeatures, pClassName, states); } + return states; } - map> Proposal::fit_local_discretization(torch::Tensor& y) + map> Proposal::fit_local_discretization(const torch::Tensor& y) { // Discretize the continuous input data and build pDataset (Classifier::dataset) int m = Xf.size(1); diff --git a/src/BayesNet/Proposal.h b/src/BayesNet/Proposal.h index 10814c2..f5eabda 100644 --- a/src/BayesNet/Proposal.h +++ b/src/BayesNet/Proposal.h @@ -14,8 +14,8 @@ namespace bayesnet { virtual ~Proposal(); protected: torch::Tensor prepareX(torch::Tensor& X); - void localDiscretizationProposal(map>& states, Network& model); - map> fit_local_discretization(torch::Tensor& y); + map> localDiscretizationProposal(const map>& states, Network& model); + map> fit_local_discretization(const torch::Tensor& y); torch::Tensor Xf; // X continuous nxm tensor torch::Tensor y; // y discrete nx1 tensor map discretizers; diff --git a/src/BayesNet/SPODELd.cc b/src/BayesNet/SPODELd.cc index 9683b7e..8a38160 100644 --- a/src/BayesNet/SPODELd.cc +++ b/src/BayesNet/SPODELd.cc @@ -15,7 +15,7 @@ namespace bayesnet { // We have discretized the input data // 1st we need to fit the model to build the normal SPODE structure, SPODE::fit initializes the base Bayesian network SPODE::fit(dataset, features, className, states); - localDiscretizationProposal(states, model); + states = localDiscretizationProposal(states, model); return *this; } SPODELd& SPODELd::fit(torch::Tensor& dataset, vector& features_, string className_, map>& states_) @@ -31,7 +31,7 @@ namespace bayesnet { // We have discretized the input data // 1st we need to fit the model to build the normal SPODE structure, SPODE::fit initializes the base Bayesian network SPODE::fit(dataset, features, className, states); - localDiscretizationProposal(states, model); + states = localDiscretizationProposal(states, model); return *this; } diff --git a/src/BayesNet/TANLd.cc b/src/BayesNet/TANLd.cc index a30cba8..e0fdebd 100644 --- a/src/BayesNet/TANLd.cc +++ b/src/BayesNet/TANLd.cc @@ -15,7 +15,7 @@ namespace bayesnet { // We have discretized the input data // 1st we need to fit the model to build the normal TAN structure, TAN::fit initializes the base Bayesian network TAN::fit(dataset, features, className, states); - localDiscretizationProposal(states, model); + states = localDiscretizationProposal(states, model); return *this; } -- 2.45.2