Max threading

This commit is contained in:
Ricardo Montañana Gómez 2023-07-31 18:49:18 +02:00
parent 43bb017d5d
commit adf650d257
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 16 additions and 7 deletions

View File

@ -68,9 +68,19 @@ namespace bayesnet {
throw logic_error("Ensemble has not been fitted"); throw logic_error("Ensemble has not been fitted");
} }
Tensor y_pred = torch::zeros({ X.size(1), n_models }, kInt32); Tensor y_pred = torch::zeros({ X.size(1), n_models }, kInt32);
//Create a threadpool
auto threads{ vector<thread>() };
auto lock = mutex();
for (auto i = 0; i < n_models; ++i) { for (auto i = 0; i < n_models; ++i) {
threads.push_back(thread([&, i]() {
auto ypredict = models[i]->predict(X); auto ypredict = models[i]->predict(X);
lock.lock();
y_pred.index_put_({ "...", i }, ypredict); y_pred.index_put_({ "...", i }, ypredict);
lock.unlock();
}));
}
for (auto& thread : threads) {
thread.join();
} }
return torch::tensor(voting(y_pred)); return torch::tensor(voting(y_pred));
} }

View File

@ -3,8 +3,8 @@
#include "Network.h" #include "Network.h"
#include "bayesnetUtils.h" #include "bayesnetUtils.h"
namespace bayesnet { namespace bayesnet {
Network::Network() : laplaceSmoothing(1), features(vector<string>()), className(""), classNumStates(0), maxThreads(0.8), fitted(false) {} Network::Network() : features(vector<string>()), className(""), classNumStates(0), fitted(false) {}
Network::Network(float maxT) : laplaceSmoothing(1), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT), fitted(false) {} Network::Network(float maxT) : features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT), fitted(false) {}
Network::Network(float maxT, int smoothing) : laplaceSmoothing(smoothing), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT), fitted(false) {} Network::Network(float maxT, int smoothing) : laplaceSmoothing(smoothing), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT), fitted(false) {}
Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), maxThreads(other.getmaxThreads()), fitted(other.fitted) Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), maxThreads(other.getmaxThreads()), fitted(other.fitted)
{ {
@ -132,7 +132,6 @@ namespace bayesnet {
} }
void Network::completeFit() void Network::completeFit()
{ {
int maxThreadsRunning = static_cast<int>(std::thread::hardware_concurrency() * maxThreads); int maxThreadsRunning = static_cast<int>(std::thread::hardware_concurrency() * maxThreads);
if (maxThreadsRunning < 1) { if (maxThreadsRunning < 1) {
maxThreadsRunning = 1; maxThreadsRunning = 1;

View File

@ -10,11 +10,11 @@ namespace bayesnet {
map<string, unique_ptr<Node>> nodes; map<string, unique_ptr<Node>> nodes;
map<string, vector<int>> dataset; map<string, vector<int>> dataset;
bool fitted; bool fitted;
float maxThreads; float maxThreads = 0.95;
int classNumStates; int classNumStates;
vector<string> features; vector<string> features;
string className; string className;
int laplaceSmoothing; int laplaceSmoothing = 1;
torch::Tensor samples; torch::Tensor samples;
bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&); bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
vector<double> predict_sample(const vector<int>&); vector<double> predict_sample(const vector<int>&);