Fix number of threads spawned
This commit is contained in:
parent
8e9090d283
commit
59c1cf5b3b
@ -190,8 +190,11 @@ namespace bayesnet {
|
|||||||
const double n_samples = static_cast<double>(samples.size(1));
|
const double n_samples = static_cast<double>(samples.size(1));
|
||||||
auto worker = [&](std::pair<const std::string, std::unique_ptr<Node>>& node, int i) {
|
auto worker = [&](std::pair<const std::string, std::unique_ptr<Node>>& node, int i) {
|
||||||
std::string threadName = "FitWorker-" + std::to_string(i);
|
std::string threadName = "FitWorker-" + std::to_string(i);
|
||||||
|
#if defined(__linux__)
|
||||||
pthread_setname_np(pthread_self(), threadName.c_str());
|
pthread_setname_np(pthread_self(), threadName.c_str());
|
||||||
semaphore.acquire();
|
#else
|
||||||
|
pthread_setname_np(threadName.c_str());
|
||||||
|
#endif
|
||||||
double numStates = static_cast<double>(node.second->getNumStates());
|
double numStates = static_cast<double>(node.second->getNumStates());
|
||||||
double smoothing_factor = 0.0;
|
double smoothing_factor = 0.0;
|
||||||
switch (smoothing) {
|
switch (smoothing) {
|
||||||
@ -212,6 +215,7 @@ namespace bayesnet {
|
|||||||
};
|
};
|
||||||
int i = 0;
|
int i = 0;
|
||||||
for (auto& node : nodes) {
|
for (auto& node : nodes) {
|
||||||
|
semaphore.acquire();
|
||||||
threads.emplace_back(worker, std::ref(node), i++);
|
threads.emplace_back(worker, std::ref(node), i++);
|
||||||
}
|
}
|
||||||
for (auto& thread : threads) {
|
for (auto& thread : threads) {
|
||||||
@ -236,8 +240,11 @@ namespace bayesnet {
|
|||||||
result = torch::zeros({ samples.size(1), classNumStates }, torch::kFloat64);
|
result = torch::zeros({ samples.size(1), classNumStates }, torch::kFloat64);
|
||||||
auto worker = [&](const torch::Tensor& sample, int i) {
|
auto worker = [&](const torch::Tensor& sample, int i) {
|
||||||
std::string threadName = "PredictWorker-" + std::to_string(i);
|
std::string threadName = "PredictWorker-" + std::to_string(i);
|
||||||
|
#if defined(__linux__)
|
||||||
pthread_setname_np(pthread_self(), threadName.c_str());
|
pthread_setname_np(pthread_self(), threadName.c_str());
|
||||||
semaphore.acquire();
|
#else
|
||||||
|
pthread_setname_np(threadName.c_str());
|
||||||
|
#endif
|
||||||
auto psample = predict_sample(sample);
|
auto psample = predict_sample(sample);
|
||||||
auto temp = torch::tensor(psample, torch::kFloat64);
|
auto temp = torch::tensor(psample, torch::kFloat64);
|
||||||
{
|
{
|
||||||
@ -247,6 +254,7 @@ namespace bayesnet {
|
|||||||
semaphore.release();
|
semaphore.release();
|
||||||
};
|
};
|
||||||
for (int i = 0; i < samples.size(1); ++i) {
|
for (int i = 0; i < samples.size(1); ++i) {
|
||||||
|
semaphore.acquire();
|
||||||
const torch::Tensor sample = samples.index({ "...", i });
|
const torch::Tensor sample = samples.index({ "...", i });
|
||||||
threads.emplace_back(worker, sample, i);
|
threads.emplace_back(worker, sample, i);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user