From b6c21c21e2a77aeefca4d7633a92e8a132adc237 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Thu, 6 Jul 2023 11:59:48 +0200 Subject: [PATCH] Add threads to exactInference --- sample/main.cc | 3 +++ src/Network.cc | 31 +++++++++++++++++++++++-------- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/sample/main.cc b/sample/main.cc index 0e60277..2881d2d 100644 --- a/sample/main.cc +++ b/sample/main.cc @@ -1,6 +1,7 @@ #include #include #include +#include #include #include "ArffFiles.h" #include "Network.h" @@ -228,5 +229,7 @@ int main(int argc, char** argv) //showCPDS(network); cout << "Score: " << network.score(Xd, y) << endl; cout << "PyTorch version: " << TORCH_VERSION << endl; + unsigned int nthreads = std::thread::hardware_concurrency(); + cout << "Computer has " << nthreads << " cores." << endl; return 0; } \ No newline at end of file diff --git a/src/Network.cc b/src/Network.cc index 7296ba4..441e431 100644 --- a/src/Network.cc +++ b/src/Network.cc @@ -1,3 +1,5 @@ +#include +#include #include "Network.h" namespace bayesnet { Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector()), className(""), classNumStates(0) {} @@ -179,18 +181,31 @@ namespace bayesnet { } vector Network::exactInference(map& evidence) { - vector result; + vector result(classNumStates, 0.0); + vector threads; + mutex mtx; + for (int i = 0; i < classNumStates; ++i) { - result.push_back(1.0); - auto complete_evidence = map(evidence); - complete_evidence[getClassName()] = i; - result[i] = computeFactor(complete_evidence); + threads.emplace_back([this, &result, &evidence, i, &mtx]() { + auto completeEvidence = map(evidence); + completeEvidence[getClassName()] = i; + double factor = computeFactor(completeEvidence); + + lock_guard lock(mtx); + result[i] = factor; + }); } + + for (auto& thread : threads) { + thread.join(); + } + // Normalize result - auto sum = accumulate(result.begin(), result.end(), 0.0); - for (int i = 0; i < result.size(); ++i) { - result[i] /= sum; + double sum = accumulate(result.begin(), result.end(), 0.0); + for (double& value : result) { + value /= sum; } + return result; } }