diff --git a/src/BaseClassifier.cc b/src/BaseClassifier.cc index 5f7b161..8f5c553 100644 --- a/src/BaseClassifier.cc +++ b/src/BaseClassifier.cc @@ -80,16 +80,22 @@ namespace bayesnet { } Tensor BaseClassifier::predict(Tensor& X) { - auto m_ = X.size(0); - auto n_ = X.size(1); - vector> Xd(n_, vector(m_, 0)); - for (auto i = 0; i < n_; i++) { - auto temp = X.index({ "...", i }); - Xd[i] = vector(temp.data_ptr(), temp.data_ptr() + m_); + auto n_models = models.size(); + Tensor y_pred = torch::zeros({ X.size(0), n_models }, torch::kInt64); + for (auto i = 0; i < n_models; ++i) { + y_pred.index_put_({ "...", i }, models[i].predict(X)); } - auto yp = model.predict(Xd); - auto ypred = torch::tensor(yp, torch::kInt64); - return ypred; + auto y_pred_ = y_pred.accessor(); + vector y_pred_final; + for (int i = 0; i < y_pred.size(0); ++i) { + vector votes(states[className].size(), 0); + for (int j = 0; j < y_pred.size(1); ++j) { + votes[y_pred_[i][j]] += 1; + } + auto indices = argsort(votes); + y_pred_final.push_back(indices[0]); + } + return torch::tensor(y_pred_final, torch::kInt64); } float BaseClassifier::score(Tensor& X, Tensor& y) { diff --git a/src/BaseClassifier.h b/src/BaseClassifier.h index aca3066..13a0872 100644 --- a/src/BaseClassifier.h +++ b/src/BaseClassifier.h @@ -1,4 +1,5 @@ #ifndef CLASSIFIERS_H +#define CLASSIFIERS_H #include #include "Network.h" #include "Metrics.hpp" diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 59728d5..6131d38 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,2 +1,2 @@ -add_library(BayesNet Network.cc Node.cc Metrics.cc BaseClassifier.cc KDB.cc TAN.cc SPODE.cc) +add_library(BayesNet Network.cc Node.cc Metrics.cc BaseClassifier.cc KDB.cc TAN.cc SPODE.cc Ensemble.cc) target_link_libraries(BayesNet "${TORCH_LIBRARIES}") \ No newline at end of file diff --git a/src/Ensemble.cc b/src/Ensemble.cc new file mode 100644 index 0000000..bd8e4fa --- /dev/null +++ b/src/Ensemble.cc @@ -0,0 +1,61 @@ +#include "Ensemble.h" + +namespace bayesnet { + using namespace std; + using namespace torch; + + Ensemble::Ensemble(BaseClassifier& model) : model(model), models(vector()), m(0), n(0), metrics(Metrics()) {} + Ensemble& Ensemble::build(vector& features, string className, map>& states) + { + + dataset = torch::cat({ X, y.view({y.size(0), 1}) }, 1); + this->features = features; + this->className = className; + this->states = states; + auto n_classes = states[className].size(); + metrics = Metrics(dataset, features, className, n_classes); + train(); + return *this; + } + Ensemble& Ensemble::fit(Tensor& X, Tensor& y, vector& features, string className, map>& states) + { + this->X = X; + this->y = y; + auto sizes = X.sizes(); + m = sizes[0]; + n = sizes[1]; + return build(features, className, states); + } + Ensemble& Ensemble::fit(vector>& X, vector& y, vector& features, string className, map>& states) + { + this->X = torch::zeros({ static_cast(X[0].size()), static_cast(X.size()) }, kInt64); + for (int i = 0; i < X.size(); ++i) { + this->X.index_put_({ "...", i }, torch::tensor(X[i], kInt64)); + } + this->y = torch::tensor(y, kInt64); + return build(features, className, states); + } + Tensor Ensemble::predict(Tensor& X) + { + auto m_ = X.size(0); + auto n_ = X.size(1); + vector> Xd(n_, vector(m_, 0)); + for (auto i = 0; i < n_; i++) { + auto temp = X.index({ "...", i }); + Xd[i] = vector(temp.data_ptr(), temp.data_ptr() + m_); + } + auto yp = model.predict(Xd); + auto ypred = torch::tensor(yp, torch::kInt64); + return ypred; + } + float Ensemble::score(Tensor& X, Tensor& y) + { + Tensor y_pred = predict(X); + return (y_pred == y).sum().item() / y.size(0); + } + vector Ensemble::show() + { + return model.show(); + } + +} \ No newline at end of file diff --git a/src/Ensemble.h b/src/Ensemble.h new file mode 100644 index 0000000..997e612 --- /dev/null +++ b/src/Ensemble.h @@ -0,0 +1,34 @@ +#ifndef ENSEMBLE_H +#define ENSEMBLE_H +#include +#include "BaseClassifier.h" +#include "Metrics.hpp" +using namespace std; +using namespace torch; + +namespace bayesnet { + class Ensemble { + private: + Ensemble& build(vector& features, string className, map>& states); + protected: + BaseClassifier& model; + vector models; + int m, n; // m: number of samples, n: number of features + Tensor X; + Tensor y; + Tensor dataset; + Metrics metrics; + vector features; + string className; + map> states; + void virtual train() = 0; + public: + Ensemble(BaseClassifier& model); + Ensemble& fit(Tensor& X, Tensor& y, vector& features, string className, map>& states); + Ensemble& fit(vector>& X, vector& y, vector& features, string className, map>& states); + Tensor predict(Tensor& X); + float score(Tensor& X, Tensor& y); + vector show(); + }; +} +#endif