Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
#include "Ensemble.h"
#include <exception>
#include <stdexcept>
8 :
9 : namespace bayesnet {
10 :
11 864 : Ensemble::Ensemble(bool predict_voting) : Classifier(Network()), n_models(0), predict_voting(predict_voting)
12 : {
13 :
14 864 : };
// Message of the std::logic_error raised when predictions are requested before fitting.
const std::string ENSEMBLE_NOT_FITTED{ "Ensemble has not been fitted" };
16 66 : void Ensemble::trainModel(const torch::Tensor& weights)
17 : {
18 66 : n_models = models.size();
19 517 : for (auto i = 0; i < n_models; ++i) {
20 : // fit with std::vectors
21 451 : models[i]->fit(dataset, features, className, states);
22 : }
23 66 : }
24 145 : std::vector<int> Ensemble::compute_arg_max(std::vector<std::vector<double>>& X)
25 : {
26 145 : std::vector<int> y_pred;
27 33363 : for (auto i = 0; i < X.size(); ++i) {
28 33218 : auto max = std::max_element(X[i].begin(), X[i].end());
29 66436 : y_pred.push_back(std::distance(X[i].begin(), max));
30 : }
31 145 : return y_pred;
32 0 : }
33 933 : torch::Tensor Ensemble::compute_arg_max(torch::Tensor& X)
34 : {
35 933 : auto y_pred = torch::argmax(X, 1);
36 933 : return y_pred;
37 : }
38 291 : torch::Tensor Ensemble::voting(torch::Tensor& votes)
39 : {
40 : // Convert m x n_models tensor to a m x n_class_states with voting probabilities
41 291 : auto y_pred_ = votes.accessor<int, 2>();
42 291 : std::vector<int> y_pred_final;
43 291 : int numClasses = states.at(className).size();
44 : // votes is m x n_models with the prediction of every model for each sample
45 291 : auto result = torch::zeros({ votes.size(0), numClasses }, torch::kFloat32);
46 291 : auto sum = std::reduce(significanceModels.begin(), significanceModels.end());
47 69474 : for (int i = 0; i < votes.size(0); ++i) {
48 : // n_votes store in each index (value of class) the significance added by each model
49 : // i.e. n_votes[0] contains how much value has the value 0 of class. That value is generated by the models predictions
50 69183 : std::vector<double> n_votes(numClasses, 0.0);
51 541708 : for (int j = 0; j < n_models; ++j) {
52 472525 : n_votes[y_pred_[i][j]] += significanceModels.at(j);
53 : }
54 69183 : result[i] = torch::tensor(n_votes);
55 69183 : }
56 : // To only do one division and gain precision
57 291 : result /= sum;
58 582 : return result;
59 291 : }
60 268 : std::vector<std::vector<double>> Ensemble::predict_proba(std::vector<std::vector<int>>& X)
61 : {
62 268 : if (!fitted) {
63 66 : throw std::logic_error(ENSEMBLE_NOT_FITTED);
64 : }
65 202 : return predict_voting ? predict_average_voting(X) : predict_average_proba(X);
66 : }
67 1010 : torch::Tensor Ensemble::predict_proba(torch::Tensor& X)
68 : {
69 1010 : if (!fitted) {
70 66 : throw std::logic_error(ENSEMBLE_NOT_FITTED);
71 : }
72 944 : return predict_voting ? predict_average_voting(X) : predict_average_proba(X);
73 : }
74 178 : std::vector<int> Ensemble::predict(std::vector<std::vector<int>>& X)
75 : {
76 178 : auto res = predict_proba(X);
77 268 : return compute_arg_max(res);
78 134 : }
79 966 : torch::Tensor Ensemble::predict(torch::Tensor& X)
80 : {
81 966 : auto res = predict_proba(X);
82 1844 : return compute_arg_max(res);
83 922 : }
84 735 : torch::Tensor Ensemble::predict_average_proba(torch::Tensor& X)
85 : {
86 735 : auto n_states = models[0]->getClassNumStates();
87 735 : torch::Tensor y_pred = torch::zeros({ X.size(1), n_states }, torch::kFloat32);
88 735 : auto threads{ std::vector<std::thread>() };
89 735 : std::mutex mtx;
90 4253 : for (auto i = 0; i < n_models; ++i) {
91 3518 : threads.push_back(std::thread([&, i]() {
92 3518 : auto ypredict = models[i]->predict_proba(X);
93 3518 : std::lock_guard<std::mutex> lock(mtx);
94 3518 : y_pred += ypredict * significanceModels[i];
95 3518 : }));
96 : }
97 4253 : for (auto& thread : threads) {
98 3518 : thread.join();
99 : }
100 735 : auto sum = std::reduce(significanceModels.begin(), significanceModels.end());
101 735 : y_pred /= sum;
102 1470 : return y_pred;
103 735 : }
104 120 : std::vector<std::vector<double>> Ensemble::predict_average_proba(std::vector<std::vector<int>>& X)
105 : {
106 120 : auto n_states = models[0]->getClassNumStates();
107 120 : std::vector<std::vector<double>> y_pred(X[0].size(), std::vector<double>(n_states, 0.0));
108 120 : auto threads{ std::vector<std::thread>() };
109 120 : std::mutex mtx;
110 842 : for (auto i = 0; i < n_models; ++i) {
111 722 : threads.push_back(std::thread([&, i]() {
112 722 : auto ypredict = models[i]->predict_proba(X);
113 722 : assert(ypredict.size() == y_pred.size());
114 722 : assert(ypredict[0].size() == y_pred[0].size());
115 722 : std::lock_guard<std::mutex> lock(mtx);
116 : // Multiply each prediction by the significance of the model and then add it to the final prediction
117 143118 : for (auto j = 0; j < ypredict.size(); ++j) {
118 142396 : std::transform(y_pred[j].begin(), y_pred[j].end(), ypredict[j].begin(), y_pred[j].begin(),
119 898532 : [significanceModels = significanceModels[i]](double x, double y) { return x + y * significanceModels; });
120 : }
121 722 : }));
122 : }
123 842 : for (auto& thread : threads) {
124 722 : thread.join();
125 : }
126 120 : auto sum = std::reduce(significanceModels.begin(), significanceModels.end());
127 : //Divide each element of the prediction by the sum of the significances
128 22520 : for (auto j = 0; j < y_pred.size(); ++j) {
129 120660 : std::transform(y_pred[j].begin(), y_pred[j].end(), y_pred[j].begin(), [sum](double x) { return x / sum; });
130 : }
131 240 : return y_pred;
132 120 : }
133 82 : std::vector<std::vector<double>> Ensemble::predict_average_voting(std::vector<std::vector<int>>& X)
134 : {
135 82 : torch::Tensor Xt = bayesnet::vectorToTensor(X, false);
136 82 : auto y_pred = predict_average_voting(Xt);
137 82 : std::vector<std::vector<double>> result = tensorToVectorDouble(y_pred);
138 164 : return result;
139 82 : }
140 291 : torch::Tensor Ensemble::predict_average_voting(torch::Tensor& X)
141 : {
142 : // Build a m x n_models tensor with the predictions of each model
143 291 : torch::Tensor y_pred = torch::zeros({ X.size(1), n_models }, torch::kInt32);
144 291 : auto threads{ std::vector<std::thread>() };
145 291 : std::mutex mtx;
146 1959 : for (auto i = 0; i < n_models; ++i) {
147 1668 : threads.push_back(std::thread([&, i]() {
148 1668 : auto ypredict = models[i]->predict(X);
149 1668 : std::lock_guard<std::mutex> lock(mtx);
150 5004 : y_pred.index_put_({ "...", i }, ypredict);
151 3336 : }));
152 : }
153 1959 : for (auto& thread : threads) {
154 1668 : thread.join();
155 : }
156 582 : return voting(y_pred);
157 291 : }
158 194 : float Ensemble::score(torch::Tensor& X, torch::Tensor& y)
159 : {
160 194 : auto y_pred = predict(X);
161 172 : int correct = 0;
162 53601 : for (int i = 0; i < y_pred.size(0); ++i) {
163 53429 : if (y_pred[i].item<int>() == y[i].item<int>()) {
164 45279 : correct++;
165 : }
166 : }
167 344 : return (double)correct / y_pred.size(0);
168 172 : }
169 134 : float Ensemble::score(std::vector<std::vector<int>>& X, std::vector<int>& y)
170 : {
171 134 : auto y_pred = predict(X);
172 112 : int correct = 0;
173 29964 : for (int i = 0; i < y_pred.size(); ++i) {
174 29852 : if (y_pred[i] == y[i]) {
175 25423 : correct++;
176 : }
177 : }
178 224 : return (double)correct / y_pred.size();
179 112 : }
180 11 : std::vector<std::string> Ensemble::show() const
181 : {
182 11 : auto result = std::vector<std::string>();
183 55 : for (auto i = 0; i < n_models; ++i) {
184 44 : auto res = models[i]->show();
185 44 : result.insert(result.end(), res.begin(), res.end());
186 44 : }
187 11 : return result;
188 0 : }
189 33 : std::vector<std::string> Ensemble::graph(const std::string& title) const
190 : {
191 33 : auto result = std::vector<std::string>();
192 220 : for (auto i = 0; i < n_models; ++i) {
193 187 : auto res = models[i]->graph(title + "_" + std::to_string(i));
194 187 : result.insert(result.end(), res.begin(), res.end());
195 187 : }
196 33 : return result;
197 0 : }
198 70 : int Ensemble::getNumberOfNodes() const
199 : {
200 70 : int nodes = 0;
201 512 : for (auto i = 0; i < n_models; ++i) {
202 442 : nodes += models[i]->getNumberOfNodes();
203 : }
204 70 : return nodes;
205 : }
206 70 : int Ensemble::getNumberOfEdges() const
207 : {
208 70 : int edges = 0;
209 512 : for (auto i = 0; i < n_models; ++i) {
210 442 : edges += models[i]->getNumberOfEdges();
211 : }
212 70 : return edges;
213 : }
214 11 : int Ensemble::getNumberOfStates() const
215 : {
216 11 : int nstates = 0;
217 55 : for (auto i = 0; i < n_models; ++i) {
218 44 : nstates += models[i]->getNumberOfStates();
219 : }
220 11 : return nstates;
221 : }
222 : }
|