Fix number of threads spawned

This commit is contained in:
Ricardo Montañana Gómez 2024-06-21 19:56:35 +02:00
parent 8e9090d283
commit 59c1cf5b3b
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE

View File

@ -190,8 +190,11 @@ namespace bayesnet {
const double n_samples = static_cast<double>(samples.size(1));
auto worker = [&](std::pair<const std::string, std::unique_ptr<Node>>& node, int i) {
std::string threadName = "FitWorker-" + std::to_string(i);
#if defined(__linux__)
pthread_setname_np(pthread_self(), threadName.c_str());
semaphore.acquire();
#else
pthread_setname_np(threadName.c_str());
#endif
double numStates = static_cast<double>(node.second->getNumStates());
double smoothing_factor = 0.0;
switch (smoothing) {
@ -212,6 +215,7 @@ namespace bayesnet {
};
int i = 0;
for (auto& node : nodes) {
semaphore.acquire();
threads.emplace_back(worker, std::ref(node), i++);
}
for (auto& thread : threads) {
@ -236,8 +240,11 @@ namespace bayesnet {
result = torch::zeros({ samples.size(1), classNumStates }, torch::kFloat64);
auto worker = [&](const torch::Tensor& sample, int i) {
std::string threadName = "PredictWorker-" + std::to_string(i);
#if defined(__linux__)
pthread_setname_np(pthread_self(), threadName.c_str());
semaphore.acquire();
#else
pthread_setname_np(threadName.c_str());
#endif
auto psample = predict_sample(sample);
auto temp = torch::tensor(psample, torch::kFloat64);
{
@ -247,6 +254,7 @@ namespace bayesnet {
semaphore.release();
};
for (int i = 0; i < samples.size(1); ++i) {
semaphore.acquire();
const torch::Tensor sample = samples.index({ "...", i });
threads.emplace_back(worker, sample, i);
}