Add predict_voting and predict_prob to ensemble

This commit is contained in:
Ricardo Montañana Gómez 2024-02-20 10:58:21 +01:00
parent a63a35df3f
commit e1c4221c11
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
2 changed files with 64 additions and 24 deletions

View File

@ -2,8 +2,8 @@
namespace bayesnet {
Ensemble::Ensemble() : Classifier(Network()), n_models(0) {}
Ensemble::Ensemble(bool predict_voting) : Classifier(Network()), n_models(0), predict_voting(predict_voting) {};
const std::string ENSEMBLE_NOT_FITTED = "Ensemble has not been fitted";
void Ensemble::trainModel(const torch::Tensor& weights)
{
n_models = models.size();
@ -12,6 +12,7 @@ namespace bayesnet {
models[i]->fit(dataset, features, className, states);
}
}
std::vector<int> Ensemble::voting(torch::Tensor& y_pred)
{
auto y_pred_ = y_pred.accessor<int, 2>();
@ -31,11 +32,55 @@ namespace bayesnet {
}
return y_pred_final;
}
std::vector<int> Ensemble::predict(std::vector<std::vector<int>>& X)
{
if (!fitted) {
throw std::logic_error(ENSEMBLE_NOT_FITTED);
}
return predict_voting ? do_predict_voting(X) : do_predict_prob(X);
}
torch::Tensor Ensemble::predict(torch::Tensor& X)
{
if (!fitted) {
throw std::logic_error("Ensemble has not been fitted");
throw std::logic_error(ENSEMBLE_NOT_FITTED);
}
return predict_voting ? do_predict_voting(X) : do_predict_prob(X);
}
torch::Tensor Ensemble::do_predict_prob(torch::Tensor& X)
{
torch::Tensor y_pred = torch::zeros({ X.size(1), n_models }, torch::kFloat32);
// auto threads{ std::vector<std::thread>() };
// std::mutex mtx;
// for (auto i = 0; i < n_models; ++i) {
// threads.push_back(std::thread([&, i]() {
// auto ypredict = models[i]->predict(X);
// std::lock_guard<std::mutex> lock(mtx);
// y_pred.index_put_({ "...", i }, ypredict);
// }));
// }
// for (auto& thread : threads) {
// thread.join();
// }
return y_pred;
}
std::vector<int> Ensemble::do_predict_prob(std::vector<std::vector<int>>& X)
{
// long m_ = X[0].size();
// long n_ = X.size();
// vector<vector<int>> Xd(n_, vector<int>(m_, 0));
// for (auto i = 0; i < n_; i++) {
// Xd[i] = vector<int>(X[i].begin(), X[i].end());
// }
// torch::Tensor y_pred = torch::zeros({ m_, n_models }, torch::kInt32);
// for (auto i = 0; i < n_models; ++i) {
// y_pred.index_put_({ "...", i }, torch::tensor(models[i]->predict(Xd), torch::kInt32));
// }
// return voting(y_pred);
return std::vector<int>();
}
torch::Tensor Ensemble::do_predict_voting(torch::Tensor& X)
{
torch::Tensor y_pred = torch::zeros({ X.size(1), n_models }, torch::kInt32);
auto threads{ std::vector<std::thread>() };
std::mutex mtx;
@ -51,11 +96,8 @@ namespace bayesnet {
}
return torch::tensor(voting(y_pred));
}
std::vector<int> Ensemble::predict(std::vector<std::vector<int>>& X)
std::vector<int> Ensemble::do_predict_voting(std::vector<std::vector<int>>& X)
{
if (!fitted) {
throw std::logic_error("Ensemble has not been fitted");
}
long m_ = X[0].size();
long n_ = X.size();
std::vector<std::vector<int>> Xd(n_, std::vector<int>(m_, 0));
@ -70,10 +112,7 @@ namespace bayesnet {
}
float Ensemble::score(torch::Tensor& X, torch::Tensor& y)
{
if (!fitted) {
throw std::logic_error("Ensemble has not been fitted");
}
auto y_pred = predict(X);
auto y_pred = predict_voting ? do_predict_voting(X) : do_predict_prob(X);
int correct = 0;
for (int i = 0; i < y_pred.size(0); ++i) {
if (y_pred[i].item<int>() == y[i].item<int>()) {
@ -84,10 +123,7 @@ namespace bayesnet {
}
float Ensemble::score(std::vector<std::vector<int>>& X, std::vector<int>& y)
{
if (!fitted) {
throw std::logic_error("Ensemble has not been fitted");
}
auto y_pred = predict(X);
auto y_pred = predict_voting ? do_predict_voting(X) : do_predict_prob(X);
int correct = 0;
for (int i = 0; i < y_pred.size(); ++i) {
if (y_pred[i] == y[i]) {

View File

@ -7,19 +7,15 @@
namespace bayesnet {
class Ensemble : public Classifier {
private:
Ensemble& build(std::vector<std::string>& features, std::string className, std::map<std::string, std::vector<int>>& states);
protected:
unsigned n_models;
std::vector<std::unique_ptr<Classifier>> models;
std::vector<double> significanceModels;
void trainModel(const torch::Tensor& weights) override;
std::vector<int> voting(torch::Tensor& y_pred);
public:
Ensemble();
Ensemble(bool predict_voting = true);
virtual ~Ensemble() = default;
torch::Tensor predict(torch::Tensor& X) override;
std::vector<int> predict(std::vector<std::vector<int>>& X) override;
torch::Tensor do_predict_voting(torch::Tensor& X);
std::vector<int> do_predict_voting(std::vector<std::vector<int>>& X);
torch::Tensor do_predict_prob(torch::Tensor& X);
std::vector<int> do_predict_prob(std::vector<std::vector<int>>& X);
float score(torch::Tensor& X, torch::Tensor& y) override;
float score(std::vector<std::vector<int>>& X, std::vector<int>& y) override;
int getNumberOfNodes() const override;
@ -34,6 +30,14 @@ namespace bayesnet {
void dump_cpt() const override
{
}
protected:
unsigned n_models;
std::vector<std::unique_ptr<Classifier>> models;
std::vector<double> significanceModels;
void trainModel(const torch::Tensor& weights) override;
std::vector<int> voting(torch::Tensor& y_pred);
private:
bool predict_voting;
};
}
#endif