diff --git a/bayesnet/network/Network.cc b/bayesnet/network/Network.cc index b19a054..d034a41 100644 --- a/bayesnet/network/Network.cc +++ b/bayesnet/network/Network.cc @@ -6,6 +6,7 @@ #include #include +#include #include #include #include "Network.h" @@ -13,10 +14,17 @@ namespace bayesnet { Network::Network() : fitted{ false }, maxThreads{ 0.95 }, classNumStates{ 0 } { + maxThreadsRunning = static_cast(std::thread::hardware_concurrency() * maxThreads); + if (maxThreadsRunning < 1) { + maxThreadsRunning = 1; + } } Network::Network(float maxT) : fitted{ false }, maxThreads{ maxT }, classNumStates{ 0 } { - + maxThreadsRunning = static_cast(std::thread::hardware_concurrency() * maxThreads); + if (maxThreadsRunning < 1 || maxT > 1) { + maxThreadsRunning = 1; + } } Network::Network(const Network& other) : features(other.features), className(other.className), classNumStates(other.getClassNumStates()), maxThreads(other.getMaxThreads()), fitted(other.fitted), samples(other.samples) @@ -192,30 +200,52 @@ namespace bayesnet { { setStates(states); std::vector threads; + std::mutex mtx; + std::condition_variable cv; + size_t activeThreads = 0; const double n_samples = static_cast(samples.size(1)); + + auto worker = [&](std::pair>& node) { + { + std::unique_lock lock(mtx); + cv.wait(lock, [&] { return activeThreads < maxThreadsRunning; }); + ++activeThreads; + } + + double numStates = static_cast(node.second->getNumStates()); + double smoothing_factor = 0.0; + + switch (smoothing) { + case Smoothing_t::ORIGINAL: + smoothing_factor = 1.0 / n_samples; + break; + case Smoothing_t::LAPLACE: + smoothing_factor = 1.0; + break; + case Smoothing_t::CESTNIK: + smoothing_factor = 1 / numStates; + break; + default: + throw std::invalid_argument("Smoothing method not recognized " + std::to_string(static_cast(smoothing))); + } + + node.second->computeCPT(samples, features, smoothing_factor, weights); + + { + std::lock_guard lock(mtx); + --activeThreads; + } + cv.notify_one(); + }; + for (auto& node : nodes) { - threads.emplace_back([this, &node, &weights, n_samples, smoothing]() { - double numStates = static_cast(node.second->getNumStates()); - double smoothing_factor = 0.0; - switch (smoothing) { - case Smoothing_t::ORIGINAL: - smoothing_factor = 1.0 / n_samples; - break; - case Smoothing_t::LAPLACE: - smoothing_factor = 1.0; - break; - case Smoothing_t::CESTNIK: // Considering m=1 pa = 1/numStates - smoothing_factor = 1 / numStates; - break; - default: - throw std::invalid_argument("Smoothing method not recognized " + std::to_string(static_cast(smoothing))); - } - node.second->computeCPT(samples, features, smoothing_factor, weights); - }); + threads.emplace_back(worker, std::ref(node)); } + for (auto& thread : threads) { thread.join(); } + fitted = true; } torch::Tensor Network::predict_tensor(const torch::Tensor& samples, const bool proba) @@ -340,15 +370,32 @@ namespace bayesnet { std::vector result(classNumStates, 0.0); std::vector threads; std::mutex mtx; - for (int i = 0; i < classNumStates; ++i) { - threads.emplace_back([this, &result, &evidence, i, &mtx]() { - auto completeEvidence = std::map(evidence); - completeEvidence[getClassName()] = i; - double factor = computeFactor(completeEvidence); + std::condition_variable cv; + size_t activeThreads = 0; + + auto worker = [&](int i) { + { + std::unique_lock lock(mtx); + cv.wait(lock, [&] { return activeThreads < maxThreadsRunning; }); + ++activeThreads; + } + + auto completeEvidence = std::map(evidence); + completeEvidence[getClassName()] = i; + double factor = computeFactor(completeEvidence); + + { std::lock_guard lock(mtx); result[i] = factor; - }); + --activeThreads; + } + cv.notify_one(); + }; + + for (int i = 0; i < classNumStates; ++i) { + threads.emplace_back(worker, i); } + for (auto& thread : threads) { thread.join(); } diff --git a/bayesnet/network/Network.h b/bayesnet/network/Network.h index 3485e64..a14540d 100644 --- a/bayesnet/network/Network.h +++ b/bayesnet/network/Network.h @@ -56,7 +56,8 @@ namespace bayesnet { private: std::map> nodes; bool fitted; - float maxThreads = 0.95; + float maxThreads = 0.95; // Coefficient to multiply by the number of threads available + int maxThreadsRunning; // Effective max number of threads running int classNumStates; std::vector features; // Including classname std::string className;