Make TANNew same as TAN with local discretization
This commit is contained in:
parent
a18fbe5594
commit
8f8f9773ce
@ -2,20 +2,52 @@
|
|||||||
|
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
using namespace std;
|
using namespace std;
|
||||||
TANNew::TANNew() : TAN(), discretizer{ mdlp::CPPFImdlp() } {}
|
TANNew::TANNew() : TAN(), n_features{ 0 } {}
|
||||||
TANNew::~TANNew() {}
|
TANNew::~TANNew() {}
|
||||||
TANNew& TANNew::fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states)
|
TANNew& TANNew::fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states)
|
||||||
{
|
{
|
||||||
|
n_features = features.size();
|
||||||
|
this->Xf = torch::transpose(X, 0, 1); // now it is mxn as X comes in nxm
|
||||||
|
this->y = y;
|
||||||
|
this->features = features;
|
||||||
|
this->className = className;
|
||||||
|
Xv = vector<vector<int>>();
|
||||||
|
yv = vector<int>(y.data_ptr<int>(), y.data_ptr<int>() + y.size(0));
|
||||||
|
for (int i = 0; i < features.size(); ++i) {
|
||||||
|
auto* discretizer = new mdlp::CPPFImdlp();
|
||||||
|
auto Xt_ptr = X.index({ i }).data_ptr<float>();
|
||||||
|
auto Xt = vector<float>(Xt_ptr, Xt_ptr + X.size(1));
|
||||||
|
discretizer->fit(Xt, yv);
|
||||||
|
Xv.push_back(discretizer->transform(Xt));
|
||||||
|
auto xStates = vector<int>(discretizer->getCutPoints().size() + 1);
|
||||||
|
iota(xStates.begin(), xStates.end(), 0);
|
||||||
|
this->states[features[i]] = xStates;
|
||||||
|
discretizers[features[i]] = discretizer;
|
||||||
|
}
|
||||||
|
int n_classes = torch::max(y).item<int>() + 1;
|
||||||
|
auto yStates = vector<int>(n_classes);
|
||||||
|
iota(yStates.begin(), yStates.end(), 0);
|
||||||
|
this->states[className] = yStates;
|
||||||
/*
|
/*
|
||||||
Hay que discretizar los datos de entrada y luego en predict discretizar también con el mmismo modelo, hacer un transform solamente.
|
Hay que discretizar los datos de entrada y luego en predict discretizar también con el mmismo modelo, hacer un transform solamente.
|
||||||
*/
|
*/
|
||||||
TAN::fit(X, y, features, className, states);
|
TAN::fit(Xv, yv, features, className, this->states);
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
void TANNew::train()
|
void TANNew::train()
|
||||||
{
|
{
|
||||||
TAN::train();
|
TAN::train();
|
||||||
}
|
}
|
||||||
|
Tensor TANNew::predict(Tensor& X)
|
||||||
|
{
|
||||||
|
auto Xtd = torch::zeros_like(X, torch::kInt32);
|
||||||
|
for (int i = 0; i < X.size(0); ++i) {
|
||||||
|
auto Xt = vector<float>(X[i].data_ptr<float>(), X[i].data_ptr<float>() + X.size(1));
|
||||||
|
auto Xd = discretizers[features[i]]->transform(Xt);
|
||||||
|
Xtd.index_put_({ i }, torch::tensor(Xd, torch::kInt32));
|
||||||
|
}
|
||||||
|
return TAN::predict(Xtd);
|
||||||
|
}
|
||||||
vector<string> TANNew::graph(const string& name)
|
vector<string> TANNew::graph(const string& name)
|
||||||
{
|
{
|
||||||
return TAN::graph(name);
|
return TAN::graph(name);
|
||||||
|
@ -7,15 +7,17 @@ namespace bayesnet {
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
class TANNew : public TAN {
|
class TANNew : public TAN {
|
||||||
private:
|
private:
|
||||||
mdlp::CPPFImdlp discretizer;
|
map<string, mdlp::CPPFImdlp*> discretizers;
|
||||||
|
int n_features;
|
||||||
|
torch::Tensor Xf; // X continuous
|
||||||
public:
|
public:
|
||||||
TANNew();
|
TANNew();
|
||||||
virtual ~TANNew();
|
virtual ~TANNew();
|
||||||
void train() override;
|
void train() override;
|
||||||
TANNew& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) override;
|
TANNew& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) override;
|
||||||
vector<string> graph(const string& name = "TAN") override;
|
vector<string> graph(const string& name = "TAN") override;
|
||||||
|
Tensor predict(Tensor& X) override;
|
||||||
static inline string version() { return "0.0.1"; };
|
static inline string version() { return "0.0.1"; };
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // !TANNEW_H
|
#endif // !TANNEW_H
|
@ -114,7 +114,7 @@ namespace platform {
|
|||||||
cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush;
|
cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush;
|
||||||
// Prepare Result
|
// Prepare Result
|
||||||
auto result = Result();
|
auto result = Result();
|
||||||
auto [values, counts] = at::_unique(y);;
|
auto [values, counts] = at::_unique(y);
|
||||||
result.setSamples(X.size(1)).setFeatures(X.size(0)).setClasses(values.size(0));
|
result.setSamples(X.size(1)).setFeatures(X.size(0)).setClasses(values.size(0));
|
||||||
int nResults = nfolds * static_cast<int>(randomSeeds.size());
|
int nResults = nfolds * static_cast<int>(randomSeeds.size());
|
||||||
auto accuracy_test = torch::zeros({ nResults }, torch::kFloat64);
|
auto accuracy_test = torch::zeros({ nResults }, torch::kFloat64);
|
||||||
|
@ -99,7 +99,6 @@ int main(int argc, char** argv)
|
|||||||
filesToTest = platform::Datasets(path, true, platform::ARFF).getNames();
|
filesToTest = platform::Datasets(path, true, platform::ARFF).getNames();
|
||||||
saveResults = true;
|
saveResults = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Begin Processing
|
* Begin Processing
|
||||||
*/
|
*/
|
||||||
|
Loading…
Reference in New Issue
Block a user