Add threads to exactInference

This commit is contained in:
Ricardo Montañana Gómez 2023-07-06 11:59:48 +02:00
parent 0d27ecd253
commit b6c21c21e2
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
2 changed files with 26 additions and 8 deletions

View File

@ -1,6 +1,7 @@
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <torch/torch.h> #include <torch/torch.h>
#include <thread>
#include <getopt.h> #include <getopt.h>
#include "ArffFiles.h" #include "ArffFiles.h"
#include "Network.h" #include "Network.h"
@ -228,5 +229,7 @@ int main(int argc, char** argv)
//showCPDS(network); //showCPDS(network);
cout << "Score: " << network.score(Xd, y) << endl; cout << "Score: " << network.score(Xd, y) << endl;
cout << "PyTorch version: " << TORCH_VERSION << endl; cout << "PyTorch version: " << TORCH_VERSION << endl;
unsigned int nthreads = std::thread::hardware_concurrency();
cout << "Computer has " << nthreads << " cores." << endl;
return 0; return 0;
} }

View File

@ -1,3 +1,5 @@
#include <thread>
#include <mutex>
#include "Network.h" #include "Network.h"
namespace bayesnet { namespace bayesnet {
Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {} Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {}
@ -179,18 +181,31 @@ namespace bayesnet {
} }
vector<double> Network::exactInference(map<string, int>& evidence) vector<double> Network::exactInference(map<string, int>& evidence)
{ {
vector<double> result; vector<double> result(classNumStates, 0.0);
vector<thread> threads;
mutex mtx;
for (int i = 0; i < classNumStates; ++i) { for (int i = 0; i < classNumStates; ++i) {
result.push_back(1.0); threads.emplace_back([this, &result, &evidence, i, &mtx]() {
auto complete_evidence = map<string, int>(evidence); auto completeEvidence = map<string, int>(evidence);
complete_evidence[getClassName()] = i; completeEvidence[getClassName()] = i;
result[i] = computeFactor(complete_evidence); double factor = computeFactor(completeEvidence);
lock_guard<mutex> lock(mtx);
result[i] = factor;
});
} }
for (auto& thread : threads) {
thread.join();
}
// Normalize result // Normalize result
auto sum = accumulate(result.begin(), result.end(), 0.0); double sum = accumulate(result.begin(), result.end(), 0.0);
for (int i = 0; i < result.size(); ++i) { for (double& value : result) {
result[i] /= sum; value /= sum;
} }
return result; return result;
} }
} }