Almost complete proposal in TANNew

Ricardo Montañana Gómez 2023-08-02 02:21:55 +02:00
parent cdfb45d2cb
commit f520b40016
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
2 changed files with 47 additions and 4 deletions

CMakeLists.txt

@@ -1,3 +1,4 @@
include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
include_directories(${BayesNet_SOURCE_DIR}/lib/Files)
add_library(BayesNet bayesnetUtils.cc Network.cc Node.cc BayesMetrics.cc Classifier.cc KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc TANNew.cc Mst.cc)
-target_link_libraries(BayesNet mdlp "${TORCH_LIBRARIES}")
+target_link_libraries(BayesNet mdlp arff "${TORCH_LIBRARIES}")

TANNew.cc

@@ -1,4 +1,5 @@
#include "TANNew.h"
#include "ArffFiles.h"
namespace bayesnet {
using namespace std;
@@ -12,11 +13,13 @@ namespace bayesnet {
this->features = features;
this->className = className;
Xv = vector<vector<int>>();
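// Xvf keeps the raw continuous values of each feature so the discretizers can be refitted on them later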
auto Xvf = vector<vector<float>>();
yv = vector<int>(y.data_ptr<int>(), y.data_ptr<int>() + y.size(0));
for (int i = 0; i < features.size(); ++i) {
auto* discretizer = new mdlp::CPPFImdlp();
auto Xt_ptr = X.index({ i }).data_ptr<float>();
auto Xt = vector<float>(Xt_ptr, Xt_ptr + X.size(1));
Xvf.push_back(Xt);
discretizer->fit(Xt, yv);
Xv.push_back(discretizer->transform(Xt));
auto xStates = vector<int>(discretizer->getCutPoints().size() + 1);
@@ -28,9 +31,48 @@ namespace bayesnet {
auto yStates = vector<int>(n_classes);
iota(yStates.begin(), yStates.end(), 0);
this->states[className] = yStates;
/*
The input data has to be discretized here, and then in predict the data must be discretized with the same fitted model, applying only a transform (a sketch of that predict step is outlined after the diff).
*/
// At this point we have a standard TAN; now we implement the proposal.
// The order of the local discretization matters: a plain 0, 1, 2... feature order is no good, so the topological order of the model is used.
auto order = model.topological_sort();
auto nodes = model.getNodes();
bool upgrade = false; // Flag to check if we need to upgrade the model
for (auto feature : order) {
auto nodeParents = nodes[feature]->getParents();
int index = find(features.begin(), features.end(), feature) - features.begin();
if (nodeParents.size() == 1) continue; // The class is its only parent, nothing to condition on
vector<string> parents;
upgrade = true;
transform(nodeParents.begin(), nodeParents.end(), back_inserter(parents), [](const auto& p) {return p->getName(); });
// Remove class as parent as it will be added later
parents.erase(remove(parents.begin(), parents.end(), className), parents.end());
vector<int> indices;
transform(parents.begin(), parents.end(), back_inserter(indices), [&](const auto& p) {return find(features.begin(), features.end(), p) - features.begin(); });
// Now we fit the discretizer of the feature conditioned on its parents and the class i.e. discretizer.fit(X[index], X[indices] + y)
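// Build one joint label per sample by concatenating the class value with the discretized values of the parents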
vector<string> yJoinParents;
transform(yv.begin(), yv.end(), back_inserter(yJoinParents), [&](const auto& p) {return to_string(p); });
for (auto idx : indices) {
for (int i = 0; i < Xvf[idx].size(); ++i) {
yJoinParents[i] += to_string(Xv[idx][i]);
}
}
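// factorize() maps each distinct joint label string to an integer code, which then acts as the supervising label for the conditional discretization fit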
auto arff = ArffFiles();
auto yxv = arff.factorize(yJoinParents);
discretizers[feature]->fit(Xvf[index], yxv);
}
if (upgrade) {
// Discretize X again with the newly fitted discretizers
Xv = vector<vector<int>>();
for (int i = 0; i < features.size(); ++i) {
auto Xt_ptr = X.index({ i }).data_ptr<float>();
auto Xt = vector<float>(Xt_ptr, Xt_ptr + X.size(1));
Xv.push_back(discretizers[features[i]]->transform(Xt));
auto xStates = vector<int>(discretizers[features[i]]->getCutPoints().size() + 1);
iota(xStates.begin(), xStates.end(), 0);
this->states[features[i]] = xStates;
}
}
TAN::fit(Xv, yv, features, className, this->states);
return *this;
}
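
The commit message describes the proposal as almost complete; the comment at the top of fit anticipates what is still missing: predict has to discretize incoming data with the same fitted discretizers, applying only a transform. A minimal sketch of that step follows, assuming a TANNew::predict(torch::Tensor&) override and an overload of TAN::predict that accepts already discretized data; these names and signatures are illustrative assumptions, not part of this commit.

// Hypothetical sketch, not part of this commit: discretize at predict time
// with the discretizers fitted above, then delegate to the base classifier.
vector<int> TANNew::predict(torch::Tensor& X)
{
    auto Xd = vector<vector<int>>();
    for (int i = 0; i < features.size(); ++i) {
        auto Xt_ptr = X.index({ i }).data_ptr<float>();
        auto Xt = vector<float>(Xt_ptr, Xt_ptr + X.size(1));
        // transform() only applies the cut points learned in fit(); no refit
        Xd.push_back(discretizers[features[i]]->transform(Xt));
    }
    return TAN::predict(Xd); // assumed overload taking discretized samples
}

Keeping fit and transform separate in the discretizers is what makes the locally fitted cut points reusable at inference time.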