Max threading
This commit is contained in:
parent
43bb017d5d
commit
adf650d257
@ -68,9 +68,19 @@ namespace bayesnet {
|
||||
throw logic_error("Ensemble has not been fitted");
|
||||
}
|
||||
Tensor y_pred = torch::zeros({ X.size(1), n_models }, kInt32);
|
||||
//Create a threadpool
|
||||
auto threads{ vector<thread>() };
|
||||
auto lock = mutex();
|
||||
for (auto i = 0; i < n_models; ++i) {
|
||||
auto ypredict = models[i]->predict(X);
|
||||
y_pred.index_put_({ "...", i }, ypredict);
|
||||
threads.push_back(thread([&, i]() {
|
||||
auto ypredict = models[i]->predict(X);
|
||||
lock.lock();
|
||||
y_pred.index_put_({ "...", i }, ypredict);
|
||||
lock.unlock();
|
||||
}));
|
||||
}
|
||||
for (auto& thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
return torch::tensor(voting(y_pred));
|
||||
}
|
||||
|
@ -3,8 +3,8 @@
|
||||
#include "Network.h"
|
||||
#include "bayesnetUtils.h"
|
||||
namespace bayesnet {
|
||||
Network::Network() : laplaceSmoothing(1), features(vector<string>()), className(""), classNumStates(0), maxThreads(0.8), fitted(false) {}
|
||||
Network::Network(float maxT) : laplaceSmoothing(1), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT), fitted(false) {}
|
||||
Network::Network() : features(vector<string>()), className(""), classNumStates(0), fitted(false) {}
|
||||
Network::Network(float maxT) : features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT), fitted(false) {}
|
||||
Network::Network(float maxT, int smoothing) : laplaceSmoothing(smoothing), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT), fitted(false) {}
|
||||
Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), maxThreads(other.getmaxThreads()), fitted(other.fitted)
|
||||
{
|
||||
@ -132,7 +132,6 @@ namespace bayesnet {
|
||||
}
|
||||
void Network::completeFit()
|
||||
{
|
||||
|
||||
int maxThreadsRunning = static_cast<int>(std::thread::hardware_concurrency() * maxThreads);
|
||||
if (maxThreadsRunning < 1) {
|
||||
maxThreadsRunning = 1;
|
||||
|
@ -10,11 +10,11 @@ namespace bayesnet {
|
||||
map<string, unique_ptr<Node>> nodes;
|
||||
map<string, vector<int>> dataset;
|
||||
bool fitted;
|
||||
float maxThreads;
|
||||
float maxThreads = 0.95;
|
||||
int classNumStates;
|
||||
vector<string> features;
|
||||
string className;
|
||||
int laplaceSmoothing;
|
||||
int laplaceSmoothing = 1;
|
||||
torch::Tensor samples;
|
||||
bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
|
||||
vector<double> predict_sample(const vector<int>&);
|
||||
|
Loading…
Reference in New Issue
Block a user