diff --git a/bayesnet/network/CountingSemaphore.h b/bayesnet/network/CountingSemaphore.h new file mode 100644 index 0000000..6f65e71 --- /dev/null +++ b/bayesnet/network/CountingSemaphore.h @@ -0,0 +1,33 @@ +#ifndef COUNTING_SEMAPHORE_H +#define COUNTING_SEMAPHORE_H +#include +#include +class CountingSemaphore { +public: + explicit CountingSemaphore(size_t max_count) : max_count_(max_count), count_(max_count) {} + + // Acquires a permit, blocking if necessary until one becomes available + void acquire() + { + std::unique_lock lock(mtx_); + cv_.wait(lock, [this]() { return count_ > 0; }); + --count_; + } + + // Releases a permit, potentially waking up a blocked acquirer + void release() + { + std::lock_guard lock(mtx_); + ++count_; + if (count_ <= max_count_) { + cv_.notify_one(); + } + } + +private: + std::mutex mtx_; + std::condition_variable cv_; + size_t max_count_; + size_t count_; +}; +#endif \ No newline at end of file diff --git a/bayesnet/network/Network.cc b/bayesnet/network/Network.cc index d034a41..aa9eb16 100644 --- a/bayesnet/network/Network.cc +++ b/bayesnet/network/Network.cc @@ -5,29 +5,25 @@ // *************************************************************** #include -#include -#include #include #include +#include +#include "CountingSemaphore.h" #include "Network.h" #include "bayesnet/utils/bayesnetUtils.h" namespace bayesnet { Network::Network() : fitted{ false }, maxThreads{ 0.95 }, classNumStates{ 0 } { - maxThreadsRunning = static_cast(std::thread::hardware_concurrency() * maxThreads); - if (maxThreadsRunning < 1) { - maxThreadsRunning = 1; - } + maxThreadsRunning = std::max(1, static_cast(std::thread::hardware_concurrency() * maxThreads)); + maxThreadsRunning = std::min(maxThreadsRunning, static_cast(std::thread::hardware_concurrency())); } Network::Network(float maxT) : fitted{ false }, maxThreads{ maxT }, classNumStates{ 0 } { - maxThreadsRunning = static_cast(std::thread::hardware_concurrency() * maxThreads); - if (maxThreadsRunning < 1 || maxT > 1) { - maxThreadsRunning = 1; - } + maxThreadsRunning = std::max(1, static_cast(std::thread::hardware_concurrency() * maxThreads)); + maxThreadsRunning = std::min(maxThreadsRunning, static_cast(std::thread::hardware_concurrency())); } Network::Network(const Network& other) : features(other.features), className(other.className), classNumStates(other.getClassNumStates()), - maxThreads(other.getMaxThreads()), fitted(other.fitted), samples(other.samples) + maxThreads(other.getMaxThreads()), fitted(other.fitted), samples(other.samples), maxThreadsRunning(other.maxThreadsRunning) { if (samples.defined()) samples = samples.clone(); @@ -200,21 +196,12 @@ namespace bayesnet { { setStates(states); std::vector threads; - std::mutex mtx; - std::condition_variable cv; - size_t activeThreads = 0; + CountingSemaphore semaphore(maxThreadsRunning); const double n_samples = static_cast(samples.size(1)); - auto worker = [&](std::pair>& node) { - { - std::unique_lock lock(mtx); - cv.wait(lock, [&] { return activeThreads < maxThreadsRunning; }); - ++activeThreads; - } - + semaphore.acquire(); double numStates = static_cast(node.second->getNumStates()); double smoothing_factor = 0.0; - switch (smoothing) { case Smoothing_t::ORIGINAL: smoothing_factor = 1.0 / n_samples; @@ -228,24 +215,15 @@ namespace bayesnet { default: throw std::invalid_argument("Smoothing method not recognized " + std::to_string(static_cast(smoothing))); } - node.second->computeCPT(samples, features, smoothing_factor, weights); - - { - std::lock_guard lock(mtx); - --activeThreads; - } - cv.notify_one(); + semaphore.release(); }; - for (auto& node : nodes) { threads.emplace_back(worker, std::ref(node)); } - for (auto& thread : threads) { thread.join(); } - fitted = true; } torch::Tensor Network::predict_tensor(const torch::Tensor& samples, const bool proba) @@ -370,32 +348,21 @@ namespace bayesnet { std::vector result(classNumStates, 0.0); std::vector threads; std::mutex mtx; - std::condition_variable cv; - size_t activeThreads = 0; - + CountingSemaphore semaphore(maxThreadsRunning); auto worker = [&](int i) { - { - std::unique_lock lock(mtx); - cv.wait(lock, [&] { return activeThreads < maxThreadsRunning; }); - ++activeThreads; - } - + semaphore.acquire(); auto completeEvidence = std::map(evidence); completeEvidence[getClassName()] = i; double factor = computeFactor(completeEvidence); - { std::lock_guard lock(mtx); result[i] = factor; - --activeThreads; } - cv.notify_one(); + semaphore.release(); }; - for (int i = 0; i < classNumStates; ++i) { threads.emplace_back(worker, i); } - for (auto& thread : threads) { thread.join(); }