diff --git a/src/BayesNet/Ensemble.cc b/src/BayesNet/Ensemble.cc index 3e4d1a6..c18b51e 100644 --- a/src/BayesNet/Ensemble.cc +++ b/src/BayesNet/Ensemble.cc @@ -68,9 +68,19 @@ namespace bayesnet { throw logic_error("Ensemble has not been fitted"); } Tensor y_pred = torch::zeros({ X.size(1), n_models }, kInt32); + //Create a threadpool + auto threads{ vector() }; + auto lock = mutex(); for (auto i = 0; i < n_models; ++i) { - auto ypredict = models[i]->predict(X); - y_pred.index_put_({ "...", i }, ypredict); + threads.push_back(thread([&, i]() { + auto ypredict = models[i]->predict(X); + lock.lock(); + y_pred.index_put_({ "...", i }, ypredict); + lock.unlock(); + })); + } + for (auto& thread : threads) { + thread.join(); } return torch::tensor(voting(y_pred)); } diff --git a/src/BayesNet/Network.cc b/src/BayesNet/Network.cc index f46e1e9..5f9aa93 100644 --- a/src/BayesNet/Network.cc +++ b/src/BayesNet/Network.cc @@ -3,8 +3,8 @@ #include "Network.h" #include "bayesnetUtils.h" namespace bayesnet { - Network::Network() : laplaceSmoothing(1), features(vector()), className(""), classNumStates(0), maxThreads(0.8), fitted(false) {} - Network::Network(float maxT) : laplaceSmoothing(1), features(vector()), className(""), classNumStates(0), maxThreads(maxT), fitted(false) {} + Network::Network() : features(vector()), className(""), classNumStates(0), fitted(false) {} + Network::Network(float maxT) : features(vector()), className(""), classNumStates(0), maxThreads(maxT), fitted(false) {} Network::Network(float maxT, int smoothing) : laplaceSmoothing(smoothing), features(vector()), className(""), classNumStates(0), maxThreads(maxT), fitted(false) {} Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), maxThreads(other.getmaxThreads()), fitted(other.fitted) { @@ -132,7 +132,6 @@ namespace bayesnet { } void Network::completeFit() { - int maxThreadsRunning = static_cast(std::thread::hardware_concurrency() * maxThreads); if (maxThreadsRunning < 1) { maxThreadsRunning = 1; diff --git a/src/BayesNet/Network.h b/src/BayesNet/Network.h index f763dde..203daa7 100644 --- a/src/BayesNet/Network.h +++ b/src/BayesNet/Network.h @@ -10,11 +10,11 @@ namespace bayesnet { map> nodes; map> dataset; bool fitted; - float maxThreads; + float maxThreads = 0.95; int classNumStates; vector features; string className; - int laplaceSmoothing; + int laplaceSmoothing = 1; torch::Tensor samples; bool isCyclic(const std::string&, std::unordered_set&, std::unordered_set&); vector predict_sample(const vector&);