Add Thread max spawning to Network

This commit is contained in:
Ricardo Montañana Gómez 2024-06-18 23:18:24 +02:00
parent fa26aa80f7
commit 0b31780d39
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
2 changed files with 74 additions and 26 deletions

View File

@ -6,6 +6,7 @@
#include <thread> #include <thread>
#include <mutex> #include <mutex>
#include <semaphore>
#include <sstream> #include <sstream>
#include <numeric> #include <numeric>
#include "Network.h" #include "Network.h"
@ -13,10 +14,17 @@
namespace bayesnet { namespace bayesnet {
Network::Network() : fitted{ false }, maxThreads{ 0.95 }, classNumStates{ 0 } Network::Network() : fitted{ false }, maxThreads{ 0.95 }, classNumStates{ 0 }
{ {
maxThreadsRunning = static_cast<int>(std::thread::hardware_concurrency() * maxThreads);
if (maxThreadsRunning < 1) {
maxThreadsRunning = 1;
}
} }
Network::Network(float maxT) : fitted{ false }, maxThreads{ maxT }, classNumStates{ 0 } Network::Network(float maxT) : fitted{ false }, maxThreads{ maxT }, classNumStates{ 0 }
{ {
maxThreadsRunning = static_cast<int>(std::thread::hardware_concurrency() * maxThreads);
if (maxThreadsRunning < 1 || maxT > 1) {
maxThreadsRunning = 1;
}
} }
Network::Network(const Network& other) : features(other.features), className(other.className), classNumStates(other.getClassNumStates()), Network::Network(const Network& other) : features(other.features), className(other.className), classNumStates(other.getClassNumStates()),
maxThreads(other.getMaxThreads()), fitted(other.fitted), samples(other.samples) maxThreads(other.getMaxThreads()), fitted(other.fitted), samples(other.samples)
@ -192,30 +200,52 @@ namespace bayesnet {
{ {
setStates(states); setStates(states);
std::vector<std::thread> threads; std::vector<std::thread> threads;
std::mutex mtx;
std::condition_variable cv;
size_t activeThreads = 0;
const double n_samples = static_cast<double>(samples.size(1)); const double n_samples = static_cast<double>(samples.size(1));
auto worker = [&](std::pair<const std::string, std::unique_ptr<Node>>& node) {
{
std::unique_lock<std::mutex> lock(mtx);
cv.wait(lock, [&] { return activeThreads < maxThreadsRunning; });
++activeThreads;
}
double numStates = static_cast<double>(node.second->getNumStates());
double smoothing_factor = 0.0;
switch (smoothing) {
case Smoothing_t::ORIGINAL:
smoothing_factor = 1.0 / n_samples;
break;
case Smoothing_t::LAPLACE:
smoothing_factor = 1.0;
break;
case Smoothing_t::CESTNIK:
smoothing_factor = 1 / numStates;
break;
default:
throw std::invalid_argument("Smoothing method not recognized " + std::to_string(static_cast<int>(smoothing)));
}
node.second->computeCPT(samples, features, smoothing_factor, weights);
{
std::lock_guard<std::mutex> lock(mtx);
--activeThreads;
}
cv.notify_one();
};
for (auto& node : nodes) { for (auto& node : nodes) {
threads.emplace_back([this, &node, &weights, n_samples, smoothing]() { threads.emplace_back(worker, std::ref(node));
double numStates = static_cast<double>(node.second->getNumStates());
double smoothing_factor = 0.0;
switch (smoothing) {
case Smoothing_t::ORIGINAL:
smoothing_factor = 1.0 / n_samples;
break;
case Smoothing_t::LAPLACE:
smoothing_factor = 1.0;
break;
case Smoothing_t::CESTNIK: // Considering m=1 pa = 1/numStates
smoothing_factor = 1 / numStates;
break;
default:
throw std::invalid_argument("Smoothing method not recognized " + std::to_string(static_cast<int>(smoothing)));
}
node.second->computeCPT(samples, features, smoothing_factor, weights);
});
} }
for (auto& thread : threads) { for (auto& thread : threads) {
thread.join(); thread.join();
} }
fitted = true; fitted = true;
} }
torch::Tensor Network::predict_tensor(const torch::Tensor& samples, const bool proba) torch::Tensor Network::predict_tensor(const torch::Tensor& samples, const bool proba)
@ -340,15 +370,32 @@ namespace bayesnet {
std::vector<double> result(classNumStates, 0.0); std::vector<double> result(classNumStates, 0.0);
std::vector<std::thread> threads; std::vector<std::thread> threads;
std::mutex mtx; std::mutex mtx;
for (int i = 0; i < classNumStates; ++i) { std::condition_variable cv;
threads.emplace_back([this, &result, &evidence, i, &mtx]() { size_t activeThreads = 0;
auto completeEvidence = std::map<std::string, int>(evidence);
completeEvidence[getClassName()] = i; auto worker = [&](int i) {
double factor = computeFactor(completeEvidence); {
std::unique_lock<std::mutex> lock(mtx);
cv.wait(lock, [&] { return activeThreads < maxThreadsRunning; });
++activeThreads;
}
auto completeEvidence = std::map<std::string, int>(evidence);
completeEvidence[getClassName()] = i;
double factor = computeFactor(completeEvidence);
{
std::lock_guard<std::mutex> lock(mtx); std::lock_guard<std::mutex> lock(mtx);
result[i] = factor; result[i] = factor;
}); --activeThreads;
}
cv.notify_one();
};
for (int i = 0; i < classNumStates; ++i) {
threads.emplace_back(worker, i);
} }
for (auto& thread : threads) { for (auto& thread : threads) {
thread.join(); thread.join();
} }

View File

@ -56,7 +56,8 @@ namespace bayesnet {
private: private:
std::map<std::string, std::unique_ptr<Node>> nodes; std::map<std::string, std::unique_ptr<Node>> nodes;
bool fitted; bool fitted;
float maxThreads = 0.95; float maxThreads = 0.95; // Coefficient to multiply by the number of threads available
int maxThreadsRunning; // Effective max number of threads running
int classNumStates; int classNumStates;
std::vector<std::string> features; // Including classname std::vector<std::string> features; // Including classname
std::string className; std::string className;