From 8f8f9773ce1932e6af506727f9b7210f751a83bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Tue, 1 Aug 2023 13:17:12 +0200 Subject: [PATCH] Make TANNew same as TAN with local discretization --- src/BayesNet/TANNew.cc | 36 ++++++++++++++++++++++++++++++++++-- src/BayesNet/TANNew.h | 6 ++++-- src/Platform/Experiment.cc | 2 +- src/Platform/main.cc | 1 - 4 files changed, 39 insertions(+), 6 deletions(-) diff --git a/src/BayesNet/TANNew.cc b/src/BayesNet/TANNew.cc index 6e7ff4e..f491356 100644 --- a/src/BayesNet/TANNew.cc +++ b/src/BayesNet/TANNew.cc @@ -2,20 +2,52 @@ namespace bayesnet { using namespace std; - TANNew::TANNew() : TAN(), discretizer{ mdlp::CPPFImdlp() } {} + TANNew::TANNew() : TAN(), n_features{ 0 } {} TANNew::~TANNew() {} TANNew& TANNew::fit(torch::Tensor& X, torch::Tensor& y, vector& features, string className, map>& states) { + n_features = features.size(); + this->Xf = torch::transpose(X, 0, 1); // now it is mxn as X comes in nxm + this->y = y; + this->features = features; + this->className = className; + Xv = vector>(); + yv = vector(y.data_ptr(), y.data_ptr() + y.size(0)); + for (int i = 0; i < features.size(); ++i) { + auto* discretizer = new mdlp::CPPFImdlp(); + auto Xt_ptr = X.index({ i }).data_ptr(); + auto Xt = vector(Xt_ptr, Xt_ptr + X.size(1)); + discretizer->fit(Xt, yv); + Xv.push_back(discretizer->transform(Xt)); + auto xStates = vector(discretizer->getCutPoints().size() + 1); + iota(xStates.begin(), xStates.end(), 0); + this->states[features[i]] = xStates; + discretizers[features[i]] = discretizer; + } + int n_classes = torch::max(y).item() + 1; + auto yStates = vector(n_classes); + iota(yStates.begin(), yStates.end(), 0); + this->states[className] = yStates; /* Hay que discretizar los datos de entrada y luego en predict discretizar tambiƩn con el mmismo modelo, hacer un transform solamente. */ - TAN::fit(X, y, features, className, states); + TAN::fit(Xv, yv, features, className, this->states); return *this; } void TANNew::train() { TAN::train(); } + Tensor TANNew::predict(Tensor& X) + { + auto Xtd = torch::zeros_like(X, torch::kInt32); + for (int i = 0; i < X.size(0); ++i) { + auto Xt = vector(X[i].data_ptr(), X[i].data_ptr() + X.size(1)); + auto Xd = discretizers[features[i]]->transform(Xt); + Xtd.index_put_({ i }, torch::tensor(Xd, torch::kInt32)); + } + return TAN::predict(Xtd); + } vector TANNew::graph(const string& name) { return TAN::graph(name); diff --git a/src/BayesNet/TANNew.h b/src/BayesNet/TANNew.h index b0bddb8..5aecb23 100644 --- a/src/BayesNet/TANNew.h +++ b/src/BayesNet/TANNew.h @@ -7,15 +7,17 @@ namespace bayesnet { using namespace std; class TANNew : public TAN { private: - mdlp::CPPFImdlp discretizer; + map discretizers; + int n_features; + torch::Tensor Xf; // X continuous public: TANNew(); virtual ~TANNew(); void train() override; TANNew& fit(torch::Tensor& X, torch::Tensor& y, vector& features, string className, map>& states) override; vector graph(const string& name = "TAN") override; + Tensor predict(Tensor& X) override; static inline string version() { return "0.0.1"; }; }; } - #endif // !TANNEW_H \ No newline at end of file diff --git a/src/Platform/Experiment.cc b/src/Platform/Experiment.cc index 58f23cc..a79216c 100644 --- a/src/Platform/Experiment.cc +++ b/src/Platform/Experiment.cc @@ -114,7 +114,7 @@ namespace platform { cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush; // Prepare Result auto result = Result(); - auto [values, counts] = at::_unique(y);; + auto [values, counts] = at::_unique(y); result.setSamples(X.size(1)).setFeatures(X.size(0)).setClasses(values.size(0)); int nResults = nfolds * static_cast(randomSeeds.size()); auto accuracy_test = torch::zeros({ nResults }, torch::kFloat64); diff --git a/src/Platform/main.cc b/src/Platform/main.cc index 55c0cfe..9a41080 100644 --- a/src/Platform/main.cc +++ b/src/Platform/main.cc @@ -99,7 +99,6 @@ int main(int argc, char** argv) filesToTest = platform::Datasets(path, true, platform::ARFF).getNames(); saveResults = true; } - /* * Begin Processing */