Add predict_voting and predict_prob to ensemble
This commit is contained in:
parent
a63a35df3f
commit
e1c4221c11
@ -2,8 +2,8 @@
|
||||
|
||||
namespace bayesnet {
|
||||
|
||||
Ensemble::Ensemble() : Classifier(Network()), n_models(0) {}
|
||||
|
||||
Ensemble::Ensemble(bool predict_voting) : Classifier(Network()), n_models(0), predict_voting(predict_voting) {};
|
||||
const std::string ENSEMBLE_NOT_FITTED = "Ensemble has not been fitted";
|
||||
void Ensemble::trainModel(const torch::Tensor& weights)
|
||||
{
|
||||
n_models = models.size();
|
||||
@ -12,6 +12,7 @@ namespace bayesnet {
|
||||
models[i]->fit(dataset, features, className, states);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int> Ensemble::voting(torch::Tensor& y_pred)
|
||||
{
|
||||
auto y_pred_ = y_pred.accessor<int, 2>();
|
||||
@ -31,11 +32,55 @@ namespace bayesnet {
|
||||
}
|
||||
return y_pred_final;
|
||||
}
|
||||
std::vector<int> Ensemble::predict(std::vector<std::vector<int>>& X)
|
||||
{
|
||||
if (!fitted) {
|
||||
throw std::logic_error(ENSEMBLE_NOT_FITTED);
|
||||
}
|
||||
return predict_voting ? do_predict_voting(X) : do_predict_prob(X);
|
||||
|
||||
}
|
||||
torch::Tensor Ensemble::predict(torch::Tensor& X)
|
||||
{
|
||||
if (!fitted) {
|
||||
throw std::logic_error("Ensemble has not been fitted");
|
||||
throw std::logic_error(ENSEMBLE_NOT_FITTED);
|
||||
}
|
||||
return predict_voting ? do_predict_voting(X) : do_predict_prob(X);
|
||||
}
|
||||
torch::Tensor Ensemble::do_predict_prob(torch::Tensor& X)
|
||||
{
|
||||
torch::Tensor y_pred = torch::zeros({ X.size(1), n_models }, torch::kFloat32);
|
||||
// auto threads{ std::vector<std::thread>() };
|
||||
// std::mutex mtx;
|
||||
// for (auto i = 0; i < n_models; ++i) {
|
||||
// threads.push_back(std::thread([&, i]() {
|
||||
// auto ypredict = models[i]->predict(X);
|
||||
// std::lock_guard<std::mutex> lock(mtx);
|
||||
// y_pred.index_put_({ "...", i }, ypredict);
|
||||
// }));
|
||||
// }
|
||||
// for (auto& thread : threads) {
|
||||
// thread.join();
|
||||
// }
|
||||
return y_pred;
|
||||
}
|
||||
std::vector<int> Ensemble::do_predict_prob(std::vector<std::vector<int>>& X)
|
||||
{
|
||||
// long m_ = X[0].size();
|
||||
// long n_ = X.size();
|
||||
// vector<vector<int>> Xd(n_, vector<int>(m_, 0));
|
||||
// for (auto i = 0; i < n_; i++) {
|
||||
// Xd[i] = vector<int>(X[i].begin(), X[i].end());
|
||||
// }
|
||||
// torch::Tensor y_pred = torch::zeros({ m_, n_models }, torch::kInt32);
|
||||
// for (auto i = 0; i < n_models; ++i) {
|
||||
// y_pred.index_put_({ "...", i }, torch::tensor(models[i]->predict(Xd), torch::kInt32));
|
||||
// }
|
||||
// return voting(y_pred);
|
||||
return std::vector<int>();
|
||||
}
|
||||
torch::Tensor Ensemble::do_predict_voting(torch::Tensor& X)
|
||||
{
|
||||
torch::Tensor y_pred = torch::zeros({ X.size(1), n_models }, torch::kInt32);
|
||||
auto threads{ std::vector<std::thread>() };
|
||||
std::mutex mtx;
|
||||
@ -51,11 +96,8 @@ namespace bayesnet {
|
||||
}
|
||||
return torch::tensor(voting(y_pred));
|
||||
}
|
||||
std::vector<int> Ensemble::predict(std::vector<std::vector<int>>& X)
|
||||
std::vector<int> Ensemble::do_predict_voting(std::vector<std::vector<int>>& X)
|
||||
{
|
||||
if (!fitted) {
|
||||
throw std::logic_error("Ensemble has not been fitted");
|
||||
}
|
||||
long m_ = X[0].size();
|
||||
long n_ = X.size();
|
||||
std::vector<std::vector<int>> Xd(n_, std::vector<int>(m_, 0));
|
||||
@ -70,10 +112,7 @@ namespace bayesnet {
|
||||
}
|
||||
float Ensemble::score(torch::Tensor& X, torch::Tensor& y)
|
||||
{
|
||||
if (!fitted) {
|
||||
throw std::logic_error("Ensemble has not been fitted");
|
||||
}
|
||||
auto y_pred = predict(X);
|
||||
auto y_pred = predict_voting ? do_predict_voting(X) : do_predict_prob(X);
|
||||
int correct = 0;
|
||||
for (int i = 0; i < y_pred.size(0); ++i) {
|
||||
if (y_pred[i].item<int>() == y[i].item<int>()) {
|
||||
@ -84,10 +123,7 @@ namespace bayesnet {
|
||||
}
|
||||
float Ensemble::score(std::vector<std::vector<int>>& X, std::vector<int>& y)
|
||||
{
|
||||
if (!fitted) {
|
||||
throw std::logic_error("Ensemble has not been fitted");
|
||||
}
|
||||
auto y_pred = predict(X);
|
||||
auto y_pred = predict_voting ? do_predict_voting(X) : do_predict_prob(X);
|
||||
int correct = 0;
|
||||
for (int i = 0; i < y_pred.size(); ++i) {
|
||||
if (y_pred[i] == y[i]) {
|
||||
|
@ -7,19 +7,15 @@
|
||||
|
||||
namespace bayesnet {
|
||||
class Ensemble : public Classifier {
|
||||
private:
|
||||
Ensemble& build(std::vector<std::string>& features, std::string className, std::map<std::string, std::vector<int>>& states);
|
||||
protected:
|
||||
unsigned n_models;
|
||||
std::vector<std::unique_ptr<Classifier>> models;
|
||||
std::vector<double> significanceModels;
|
||||
void trainModel(const torch::Tensor& weights) override;
|
||||
std::vector<int> voting(torch::Tensor& y_pred);
|
||||
public:
|
||||
Ensemble();
|
||||
Ensemble(bool predict_voting = true);
|
||||
virtual ~Ensemble() = default;
|
||||
torch::Tensor predict(torch::Tensor& X) override;
|
||||
std::vector<int> predict(std::vector<std::vector<int>>& X) override;
|
||||
torch::Tensor do_predict_voting(torch::Tensor& X);
|
||||
std::vector<int> do_predict_voting(std::vector<std::vector<int>>& X);
|
||||
torch::Tensor do_predict_prob(torch::Tensor& X);
|
||||
std::vector<int> do_predict_prob(std::vector<std::vector<int>>& X);
|
||||
float score(torch::Tensor& X, torch::Tensor& y) override;
|
||||
float score(std::vector<std::vector<int>>& X, std::vector<int>& y) override;
|
||||
int getNumberOfNodes() const override;
|
||||
@ -34,6 +30,14 @@ namespace bayesnet {
|
||||
void dump_cpt() const override
|
||||
{
|
||||
}
|
||||
protected:
|
||||
unsigned n_models;
|
||||
std::vector<std::unique_ptr<Classifier>> models;
|
||||
std::vector<double> significanceModels;
|
||||
void trainModel(const torch::Tensor& weights) override;
|
||||
std::vector<int> voting(torch::Tensor& y_pred);
|
||||
private:
|
||||
bool predict_voting;
|
||||
};
|
||||
}
|
||||
#endif
|
||||
|
Loading…
Reference in New Issue
Block a user