Add threads to exactInference

This commit is contained in:
Ricardo Montañana Gómez 2023-07-06 11:59:48 +02:00
parent 0d27ecd253
commit b6c21c21e2
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
2 changed files with 26 additions and 8 deletions

View File

@ -1,6 +1,7 @@
#include <iostream>
#include <string>
#include <torch/torch.h>
#include <thread>
#include <getopt.h>
#include "ArffFiles.h"
#include "Network.h"
@ -228,5 +229,7 @@ int main(int argc, char** argv)
//showCPDS(network);
cout << "Score: " << network.score(Xd, y) << endl;
cout << "PyTorch version: " << TORCH_VERSION << endl;
unsigned int nthreads = std::thread::hardware_concurrency();
cout << "Computer has " << nthreads << " cores." << endl;
return 0;
}

View File

@ -1,3 +1,5 @@
#include <thread>
#include <mutex>
#include "Network.h"
namespace bayesnet {
Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {}
@ -179,18 +181,31 @@ namespace bayesnet {
}
vector<double> Network::exactInference(map<string, int>& evidence)
{
vector<double> result;
vector<double> result(classNumStates, 0.0);
vector<thread> threads;
mutex mtx;
for (int i = 0; i < classNumStates; ++i) {
result.push_back(1.0);
auto complete_evidence = map<string, int>(evidence);
complete_evidence[getClassName()] = i;
result[i] = computeFactor(complete_evidence);
threads.emplace_back([this, &result, &evidence, i, &mtx]() {
auto completeEvidence = map<string, int>(evidence);
completeEvidence[getClassName()] = i;
double factor = computeFactor(completeEvidence);
lock_guard<mutex> lock(mtx);
result[i] = factor;
});
}
for (auto& thread : threads) {
thread.join();
}
// Normalize result
auto sum = accumulate(result.begin(), result.end(), 0.0);
for (int i = 0; i < result.size(); ++i) {
result[i] /= sum;
double sum = accumulate(result.begin(), result.end(), 0.0);
for (double& value : result) {
value /= sum;
}
return result;
}
}