diff --git a/bayesnet/classifiers/Proposal.cc b/bayesnet/classifiers/Proposal.cc index 6974b70..a2a25ad 100644 --- a/bayesnet/classifiers/Proposal.cc +++ b/bayesnet/classifiers/Proposal.cc @@ -118,37 +118,31 @@ namespace bayesnet { } return states; } - map> Proposal::fit_local_discretization(const torch::Tensor& y, map> states) + map> Proposal::fit_local_discretization(const torch::Tensor& y, map> states_) { // Discretize the continuous input data and build pDataset (Classifier::dataset) - // We expect to have in states for numeric features an empty vector and for discretized features a vector of states int m = Xf.size(1); int n = Xf.size(0); + map> states; pDataset = torch::zeros({ n + 1, m }, torch::kInt32); auto yv = std::vector(y.data_ptr(), y.data_ptr() + y.size(0)); // discretize input data by feature(row) std::unique_ptr discretizer; for (auto i = 0; i < pFeatures.size(); ++i) { + if (discretizationType == discretization_t::BINQ) { + discretizer = std::make_unique(ld_params.proposed_cuts, mdlp::strategy_t::QUANTILE); + } else if (discretizationType == discretization_t::BINU) { + discretizer = std::make_unique(ld_params.proposed_cuts, mdlp::strategy_t::UNIFORM); + } else { // Default is MDLP + discretizer = std::make_unique(ld_params.min_length, ld_params.max_depth, ld_params.proposed_cuts); + } auto Xt_ptr = Xf.index({ i }).data_ptr(); auto Xt = std::vector(Xt_ptr, Xt_ptr + Xf.size(1)); - if (states[pFeatures[i]].empty()) { - // If the feature is numeric, we discretize it - if (discretizationType == discretization_t::BINQ) { - discretizer = std::make_unique(ld_params.proposed_cuts, mdlp::strategy_t::QUANTILE); - } else if (discretizationType == discretization_t::BINU) { - discretizer = std::make_unique(ld_params.proposed_cuts, mdlp::strategy_t::UNIFORM); - } else { // Default is MDLP - discretizer = std::make_unique(ld_params.min_length, ld_params.max_depth, ld_params.proposed_cuts); - } - pDataset.index_put_({ i, "..." }, torch::tensor(discretizer->fit_transform(Xt, yv))); - int n_states = discretizer->getCutPoints().size() + 1; - auto xStates = std::vector(n_states); - iota(xStates.begin(), xStates.end(), 0); - states[pFeatures[i]] = xStates; - } else { - // If the feature is categorical, we just copy it - pDataset.index_put_({ i, "..." }, Xf[i].to(torch::kInt32)); - } + discretizer->fit(Xt, yv); + pDataset.index_put_({ i, "..." }, torch::tensor(discretizer->transform(Xt))); + auto xStates = std::vector(discretizer->getCutPoints().size() + 1); + iota(xStates.begin(), xStates.end(), 0); + states[pFeatures[i]] = xStates; discretizers[pFeatures[i]] = std::move(discretizer); } int n_classes = torch::max(y).item() + 1;