Refactor CountingSemaphore as singleton

This commit is contained in:
2024-06-21 09:30:24 +02:00
parent 716748e18c
commit 02bcab01be
5 changed files with 79 additions and 91 deletions

View File

@@ -1,33 +0,0 @@
#ifndef COUNTING_SEMAPHORE_H
#define COUNTING_SEMAPHORE_H
#include <mutex>
#include <condition_variable>
class CountingSemaphore {
public:
explicit CountingSemaphore(size_t max_count) : max_count_(max_count), count_(max_count) {}
// Acquires a permit, blocking if necessary until one becomes available
void acquire()
{
std::unique_lock<std::mutex> lock(mtx_);
cv_.wait(lock, [this]() { return count_ > 0; });
--count_;
}
// Releases a permit, potentially waking up a blocked acquirer
void release()
{
std::lock_guard<std::mutex> lock(mtx_);
++count_;
if (count_ <= max_count_) {
cv_.notify_one();
}
}
private:
std::mutex mtx_;
std::condition_variable cv_;
size_t max_count_;
size_t count_;
};
#endif

View File

@@ -8,22 +8,16 @@
#include <sstream>
#include <numeric>
#include <algorithm>
#include "CountingSemaphore.h"
#include "Network.h"
#include "bayesnet/utils/bayesnetUtils.h"
#include "bayesnet/utils/CountingSemaphore.h"
#include <pthread.h>
namespace bayesnet {
Network::Network() : fitted{ false }, maxThreads{ 0.95 }, classNumStates{ 0 }
Network::Network() : fitted{ false }, classNumStates{ 0 }
{
maxThreadsRunning = std::max(1, static_cast<int>(std::thread::hardware_concurrency() * maxThreads));
maxThreadsRunning = std::min(maxThreadsRunning, static_cast<int>(std::thread::hardware_concurrency()));
}
Network::Network(float maxT) : fitted{ false }, maxThreads{ maxT }, classNumStates{ 0 }
{
maxThreadsRunning = std::max(1, static_cast<int>(std::thread::hardware_concurrency() * maxThreads));
maxThreadsRunning = std::min(maxThreadsRunning, static_cast<int>(std::thread::hardware_concurrency()));
}
Network::Network(const Network& other) : features(other.features), className(other.className), classNumStates(other.getClassNumStates()),
maxThreads(other.getMaxThreads()), fitted(other.fitted), samples(other.samples), maxThreadsRunning(other.maxThreadsRunning)
fitted(other.fitted), samples(other.samples)
{
if (samples.defined())
samples = samples.clone();
@@ -40,10 +34,6 @@ namespace bayesnet {
nodes.clear();
samples = torch::Tensor();
}
float Network::getMaxThreads() const
{
return maxThreads;
}
torch::Tensor& Network::getSamples()
{
return samples;
@@ -196,9 +186,11 @@ namespace bayesnet {
{
setStates(states);
std::vector<std::thread> threads;
CountingSemaphore semaphore(maxThreadsRunning);
auto& semaphore = CountingSemaphore::getInstance();
const double n_samples = static_cast<double>(samples.size(1));
auto worker = [&](std::pair<const std::string, std::unique_ptr<Node>>& node) {
auto worker = [&](std::pair<const std::string, std::unique_ptr<Node>>& node, int i) {
std::string threadName = "FitWorker-" + std::to_string(i);
pthread_setname_np(pthread_self(), threadName.c_str());
semaphore.acquire();
double numStates = static_cast<double>(node.second->getNumStates());
double smoothing_factor = 0.0;
@@ -218,8 +210,9 @@ namespace bayesnet {
node.second->computeCPT(samples, features, smoothing_factor, weights);
semaphore.release();
};
int i = 0;
for (auto& node : nodes) {
threads.emplace_back(worker, std::ref(node));
threads.emplace_back(worker, std::ref(node), i++);
}
for (auto& thread : threads) {
thread.join();
@@ -345,12 +338,21 @@ namespace bayesnet {
}
std::vector<double> Network::exactInference(std::map<std::string, int>& evidence)
{
//Implementar una cache para acelerar la inferencia.
// Cambiar la estrategia de crear hilos en la inferencia (por nodos como en fit?)
std::vector<double> result(classNumStates, 0.0);
std::vector<std::thread> threads;
std::mutex mtx;
CountingSemaphore semaphore(maxThreadsRunning);
auto& semaphore = CountingSemaphore::getInstance();
auto worker = [&](int i) {
semaphore.acquire();
std::string threadName = "InferenceWorker-" + std::to_string(i);
pthread_setname_np(pthread_self(), threadName.c_str());
auto completeEvidence = std::map<std::string, int>(evidence);
completeEvidence[getClassName()] = i;
double factor = computeFactor(completeEvidence);

View File

@@ -56,8 +56,6 @@ namespace bayesnet {
private:
std::map<std::string, std::unique_ptr<Node>> nodes;
bool fitted;
float maxThreads = 0.95; // Coefficient to multiply by the number of threads available
int maxThreadsRunning; // Effective max number of threads running
int classNumStates;
std::vector<std::string> features; // Including classname
std::string className;