Add Counting Semaphore class
Fix threading in Network
This commit is contained in:
parent
0b31780d39
commit
716748e18c
33
bayesnet/network/CountingSemaphore.h
Normal file
33
bayesnet/network/CountingSemaphore.h
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
#ifndef COUNTING_SEMAPHORE_H
|
||||||
|
#define COUNTING_SEMAPHORE_H
|
||||||
|
#include <mutex>
|
||||||
|
#include <condition_variable>
|
||||||
|
class CountingSemaphore {
|
||||||
|
public:
|
||||||
|
explicit CountingSemaphore(size_t max_count) : max_count_(max_count), count_(max_count) {}
|
||||||
|
|
||||||
|
// Acquires a permit, blocking if necessary until one becomes available
|
||||||
|
void acquire()
|
||||||
|
{
|
||||||
|
std::unique_lock<std::mutex> lock(mtx_);
|
||||||
|
cv_.wait(lock, [this]() { return count_ > 0; });
|
||||||
|
--count_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Releases a permit, potentially waking up a blocked acquirer
|
||||||
|
void release()
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> lock(mtx_);
|
||||||
|
++count_;
|
||||||
|
if (count_ <= max_count_) {
|
||||||
|
cv_.notify_one();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::mutex mtx_;
|
||||||
|
std::condition_variable cv_;
|
||||||
|
size_t max_count_;
|
||||||
|
size_t count_;
|
||||||
|
};
|
||||||
|
#endif
|
@ -5,29 +5,25 @@
|
|||||||
// ***************************************************************
|
// ***************************************************************
|
||||||
|
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <mutex>
|
|
||||||
#include <semaphore>
|
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
#include <algorithm>
|
||||||
|
#include "CountingSemaphore.h"
|
||||||
#include "Network.h"
|
#include "Network.h"
|
||||||
#include "bayesnet/utils/bayesnetUtils.h"
|
#include "bayesnet/utils/bayesnetUtils.h"
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
Network::Network() : fitted{ false }, maxThreads{ 0.95 }, classNumStates{ 0 }
|
Network::Network() : fitted{ false }, maxThreads{ 0.95 }, classNumStates{ 0 }
|
||||||
{
|
{
|
||||||
maxThreadsRunning = static_cast<int>(std::thread::hardware_concurrency() * maxThreads);
|
maxThreadsRunning = std::max(1, static_cast<int>(std::thread::hardware_concurrency() * maxThreads));
|
||||||
if (maxThreadsRunning < 1) {
|
maxThreadsRunning = std::min(maxThreadsRunning, static_cast<int>(std::thread::hardware_concurrency()));
|
||||||
maxThreadsRunning = 1;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
Network::Network(float maxT) : fitted{ false }, maxThreads{ maxT }, classNumStates{ 0 }
|
Network::Network(float maxT) : fitted{ false }, maxThreads{ maxT }, classNumStates{ 0 }
|
||||||
{
|
{
|
||||||
maxThreadsRunning = static_cast<int>(std::thread::hardware_concurrency() * maxThreads);
|
maxThreadsRunning = std::max(1, static_cast<int>(std::thread::hardware_concurrency() * maxThreads));
|
||||||
if (maxThreadsRunning < 1 || maxT > 1) {
|
maxThreadsRunning = std::min(maxThreadsRunning, static_cast<int>(std::thread::hardware_concurrency()));
|
||||||
maxThreadsRunning = 1;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
Network::Network(const Network& other) : features(other.features), className(other.className), classNumStates(other.getClassNumStates()),
|
Network::Network(const Network& other) : features(other.features), className(other.className), classNumStates(other.getClassNumStates()),
|
||||||
maxThreads(other.getMaxThreads()), fitted(other.fitted), samples(other.samples)
|
maxThreads(other.getMaxThreads()), fitted(other.fitted), samples(other.samples), maxThreadsRunning(other.maxThreadsRunning)
|
||||||
{
|
{
|
||||||
if (samples.defined())
|
if (samples.defined())
|
||||||
samples = samples.clone();
|
samples = samples.clone();
|
||||||
@ -200,21 +196,12 @@ namespace bayesnet {
|
|||||||
{
|
{
|
||||||
setStates(states);
|
setStates(states);
|
||||||
std::vector<std::thread> threads;
|
std::vector<std::thread> threads;
|
||||||
std::mutex mtx;
|
CountingSemaphore semaphore(maxThreadsRunning);
|
||||||
std::condition_variable cv;
|
|
||||||
size_t activeThreads = 0;
|
|
||||||
const double n_samples = static_cast<double>(samples.size(1));
|
const double n_samples = static_cast<double>(samples.size(1));
|
||||||
|
|
||||||
auto worker = [&](std::pair<const std::string, std::unique_ptr<Node>>& node) {
|
auto worker = [&](std::pair<const std::string, std::unique_ptr<Node>>& node) {
|
||||||
{
|
semaphore.acquire();
|
||||||
std::unique_lock<std::mutex> lock(mtx);
|
|
||||||
cv.wait(lock, [&] { return activeThreads < maxThreadsRunning; });
|
|
||||||
++activeThreads;
|
|
||||||
}
|
|
||||||
|
|
||||||
double numStates = static_cast<double>(node.second->getNumStates());
|
double numStates = static_cast<double>(node.second->getNumStates());
|
||||||
double smoothing_factor = 0.0;
|
double smoothing_factor = 0.0;
|
||||||
|
|
||||||
switch (smoothing) {
|
switch (smoothing) {
|
||||||
case Smoothing_t::ORIGINAL:
|
case Smoothing_t::ORIGINAL:
|
||||||
smoothing_factor = 1.0 / n_samples;
|
smoothing_factor = 1.0 / n_samples;
|
||||||
@ -228,24 +215,15 @@ namespace bayesnet {
|
|||||||
default:
|
default:
|
||||||
throw std::invalid_argument("Smoothing method not recognized " + std::to_string(static_cast<int>(smoothing)));
|
throw std::invalid_argument("Smoothing method not recognized " + std::to_string(static_cast<int>(smoothing)));
|
||||||
}
|
}
|
||||||
|
|
||||||
node.second->computeCPT(samples, features, smoothing_factor, weights);
|
node.second->computeCPT(samples, features, smoothing_factor, weights);
|
||||||
|
semaphore.release();
|
||||||
{
|
|
||||||
std::lock_guard<std::mutex> lock(mtx);
|
|
||||||
--activeThreads;
|
|
||||||
}
|
|
||||||
cv.notify_one();
|
|
||||||
};
|
};
|
||||||
|
|
||||||
for (auto& node : nodes) {
|
for (auto& node : nodes) {
|
||||||
threads.emplace_back(worker, std::ref(node));
|
threads.emplace_back(worker, std::ref(node));
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto& thread : threads) {
|
for (auto& thread : threads) {
|
||||||
thread.join();
|
thread.join();
|
||||||
}
|
}
|
||||||
|
|
||||||
fitted = true;
|
fitted = true;
|
||||||
}
|
}
|
||||||
torch::Tensor Network::predict_tensor(const torch::Tensor& samples, const bool proba)
|
torch::Tensor Network::predict_tensor(const torch::Tensor& samples, const bool proba)
|
||||||
@ -370,32 +348,21 @@ namespace bayesnet {
|
|||||||
std::vector<double> result(classNumStates, 0.0);
|
std::vector<double> result(classNumStates, 0.0);
|
||||||
std::vector<std::thread> threads;
|
std::vector<std::thread> threads;
|
||||||
std::mutex mtx;
|
std::mutex mtx;
|
||||||
std::condition_variable cv;
|
CountingSemaphore semaphore(maxThreadsRunning);
|
||||||
size_t activeThreads = 0;
|
|
||||||
|
|
||||||
auto worker = [&](int i) {
|
auto worker = [&](int i) {
|
||||||
{
|
semaphore.acquire();
|
||||||
std::unique_lock<std::mutex> lock(mtx);
|
|
||||||
cv.wait(lock, [&] { return activeThreads < maxThreadsRunning; });
|
|
||||||
++activeThreads;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto completeEvidence = std::map<std::string, int>(evidence);
|
auto completeEvidence = std::map<std::string, int>(evidence);
|
||||||
completeEvidence[getClassName()] = i;
|
completeEvidence[getClassName()] = i;
|
||||||
double factor = computeFactor(completeEvidence);
|
double factor = computeFactor(completeEvidence);
|
||||||
|
|
||||||
{
|
{
|
||||||
std::lock_guard<std::mutex> lock(mtx);
|
std::lock_guard<std::mutex> lock(mtx);
|
||||||
result[i] = factor;
|
result[i] = factor;
|
||||||
--activeThreads;
|
|
||||||
}
|
}
|
||||||
cv.notify_one();
|
semaphore.release();
|
||||||
};
|
};
|
||||||
|
|
||||||
for (int i = 0; i < classNumStates; ++i) {
|
for (int i = 0; i < classNumStates; ++i) {
|
||||||
threads.emplace_back(worker, i);
|
threads.emplace_back(worker, i);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto& thread : threads) {
|
for (auto& thread : threads) {
|
||||||
thread.join();
|
thread.join();
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user