// *************************************************************** // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez // SPDX-FileType: SOURCE // SPDX-License-Identifier: MIT // *************************************************************** #include "Ensemble.h" #include "bayesnet/utils/CountingSemaphore.h" namespace bayesnet { Ensemble::Ensemble(bool predict_voting) : Classifier(Network()), n_models(0), predict_voting(predict_voting) { }; const std::string ENSEMBLE_NOT_FITTED = "Ensemble has not been fitted"; void Ensemble::trainModel(const torch::Tensor& weights, const Smoothing_t smoothing) { n_models = models.size(); for (auto i = 0; i < n_models; ++i) { // fit with std::vectors models[i]->fit(dataset, features, className, states, smoothing); } } std::vector Ensemble::compute_arg_max(std::vector>& X) { std::vector y_pred; for (auto i = 0; i < X.size(); ++i) { auto max = std::max_element(X[i].begin(), X[i].end()); y_pred.push_back(std::distance(X[i].begin(), max)); } return y_pred; } torch::Tensor Ensemble::compute_arg_max(torch::Tensor& X) { auto y_pred = torch::argmax(X, 1); return y_pred; } torch::Tensor Ensemble::voting(torch::Tensor& votes) { // Convert m x n_models tensor to a m x n_class_states with voting probabilities auto y_pred_ = votes.accessor(); std::vector y_pred_final; int numClasses = states.at(className).size(); // votes is m x n_models with the prediction of every model for each sample auto result = torch::zeros({ votes.size(0), numClasses }, torch::kFloat32); auto sum = std::reduce(significanceModels.begin(), significanceModels.end()); for (int i = 0; i < votes.size(0); ++i) { // n_votes store in each index (value of class) the significance added by each model // i.e. n_votes[0] contains how much value has the value 0 of class. That value is generated by the models predictions std::vector n_votes(numClasses, 0.0); for (int j = 0; j < n_models; ++j) { n_votes[y_pred_[i][j]] += significanceModels.at(j); } result[i] = torch::tensor(n_votes); } // To only do one division and gain precision result /= sum; return result; } std::vector> Ensemble::predict_proba(std::vector>& X) { if (!fitted) { throw std::logic_error(ENSEMBLE_NOT_FITTED); } return predict_voting ? predict_average_voting(X) : predict_average_proba(X); } torch::Tensor Ensemble::predict_proba(torch::Tensor& X) { if (!fitted) { throw std::logic_error(ENSEMBLE_NOT_FITTED); } return predict_voting ? predict_average_voting(X) : predict_average_proba(X); } std::vector Ensemble::predict(std::vector>& X) { auto res = predict_proba(X); return compute_arg_max(res); } torch::Tensor Ensemble::predict(torch::Tensor& X) { auto res = predict_proba(X); return compute_arg_max(res); } torch::Tensor Ensemble::predict_average_proba(torch::Tensor& X) { auto n_states = models[0]->getClassNumStates(); torch::Tensor y_pred = torch::zeros({ X.size(1), n_states }, torch::kFloat32); for (auto i = 0; i < n_models; ++i) { auto ypredict = models[i]->predict_proba(X); y_pred += ypredict * significanceModels[i]; } auto sum = std::reduce(significanceModels.begin(), significanceModels.end()); y_pred /= sum; return y_pred; } std::vector> Ensemble::predict_average_proba(std::vector>& X) { auto n_states = models[0]->getClassNumStates(); std::vector> y_pred(X[0].size(), std::vector(n_states, 0.0)); for (auto i = 0; i < n_models; ++i) { auto ypredict = models[i]->predict_proba(X); assert(ypredict.size() == y_pred.size()); assert(ypredict[0].size() == y_pred[0].size()); // Multiply each prediction by the significance of the model and then add it to the final prediction for (auto j = 0; j < ypredict.size(); ++j) { std::transform(y_pred[j].begin(), y_pred[j].end(), ypredict[j].begin(), y_pred[j].begin(), [significanceModels = significanceModels[i]](double x, double y) { return x + y * significanceModels; }); } } auto sum = std::reduce(significanceModels.begin(), significanceModels.end()); //Divide each element of the prediction by the sum of the significances for (auto j = 0; j < y_pred.size(); ++j) { std::transform(y_pred[j].begin(), y_pred[j].end(), y_pred[j].begin(), [sum](double x) { return x / sum; }); } return y_pred; } std::vector> Ensemble::predict_average_voting(std::vector>& X) { torch::Tensor Xt = bayesnet::vectorToTensor(X, false); auto y_pred = predict_average_voting(Xt); std::vector> result = tensorToVectorDouble(y_pred); return result; } torch::Tensor Ensemble::predict_average_voting(torch::Tensor& X) { // Build a m x n_models tensor with the predictions of each model torch::Tensor y_pred = torch::zeros({ X.size(1), n_models }, torch::kInt32); for (auto i = 0; i < n_models; ++i) { auto ypredict = models[i]->predict(X); y_pred.index_put_({ "...", i }, ypredict); } return voting(y_pred); } float Ensemble::score(torch::Tensor& X, torch::Tensor& y) { auto y_pred = predict(X); int correct = 0; for (int i = 0; i < y_pred.size(0); ++i) { if (y_pred[i].item() == y[i].item()) { correct++; } } return (double)correct / y_pred.size(0); } float Ensemble::score(std::vector>& X, std::vector& y) { auto y_pred = predict(X); int correct = 0; for (int i = 0; i < y_pred.size(); ++i) { if (y_pred[i] == y[i]) { correct++; } } return (double)correct / y_pred.size(); } std::vector Ensemble::show() const { auto result = std::vector(); for (auto i = 0; i < n_models; ++i) { auto res = models[i]->show(); result.insert(result.end(), res.begin(), res.end()); } return result; } std::vector Ensemble::graph(const std::string& title) const { auto result = std::vector(); for (auto i = 0; i < n_models; ++i) { auto res = models[i]->graph(title + "_" + std::to_string(i)); result.insert(result.end(), res.begin(), res.end()); } return result; } int Ensemble::getNumberOfNodes() const { int nodes = 0; for (auto i = 0; i < n_models; ++i) { nodes += models[i]->getNumberOfNodes(); } return nodes; } int Ensemble::getNumberOfEdges() const { int edges = 0; for (auto i = 0; i < n_models; ++i) { edges += models[i]->getNumberOfEdges(); } return edges; } int Ensemble::getNumberOfStates() const { int nstates = 0; for (auto i = 0; i < n_models; ++i) { nstates += models[i]->getNumberOfStates(); } return nstates; } }