Max threading

This commit is contained in:
Ricardo Montañana Gómez 2023-07-31 18:49:18 +02:00
parent 43bb017d5d
commit adf650d257
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 16 additions and 7 deletions

View File

@ -68,9 +68,19 @@ namespace bayesnet {
throw logic_error("Ensemble has not been fitted");
}
Tensor y_pred = torch::zeros({ X.size(1), n_models }, kInt32);
//Create a threadpool
auto threads{ vector<thread>() };
auto lock = mutex();
for (auto i = 0; i < n_models; ++i) {
auto ypredict = models[i]->predict(X);
y_pred.index_put_({ "...", i }, ypredict);
threads.push_back(thread([&, i]() {
auto ypredict = models[i]->predict(X);
lock.lock();
y_pred.index_put_({ "...", i }, ypredict);
lock.unlock();
}));
}
for (auto& thread : threads) {
thread.join();
}
return torch::tensor(voting(y_pred));
}

View File

@ -3,8 +3,8 @@
#include "Network.h"
#include "bayesnetUtils.h"
namespace bayesnet {
Network::Network() : laplaceSmoothing(1), features(vector<string>()), className(""), classNumStates(0), maxThreads(0.8), fitted(false) {}
Network::Network(float maxT) : laplaceSmoothing(1), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT), fitted(false) {}
Network::Network() : features(vector<string>()), className(""), classNumStates(0), fitted(false) {}
Network::Network(float maxT) : features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT), fitted(false) {}
Network::Network(float maxT, int smoothing) : laplaceSmoothing(smoothing), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT), fitted(false) {}
Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), maxThreads(other.getmaxThreads()), fitted(other.fitted)
{
@ -132,7 +132,6 @@ namespace bayesnet {
}
void Network::completeFit()
{
int maxThreadsRunning = static_cast<int>(std::thread::hardware_concurrency() * maxThreads);
if (maxThreadsRunning < 1) {
maxThreadsRunning = 1;

View File

@ -10,11 +10,11 @@ namespace bayesnet {
map<string, unique_ptr<Node>> nodes;
map<string, vector<int>> dataset;
bool fitted;
float maxThreads;
float maxThreads = 0.95;
int classNumStates;
vector<string> features;
string className;
int laplaceSmoothing;
int laplaceSmoothing = 1;
torch::Tensor samples;
bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
vector<double> predict_sample(const vector<int>&);