Add threads to exactInference
This commit is contained in:
parent
0d27ecd253
commit
b6c21c21e2
@ -1,6 +1,7 @@
|
|||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <torch/torch.h>
|
#include <torch/torch.h>
|
||||||
|
#include <thread>
|
||||||
#include <getopt.h>
|
#include <getopt.h>
|
||||||
#include "ArffFiles.h"
|
#include "ArffFiles.h"
|
||||||
#include "Network.h"
|
#include "Network.h"
|
||||||
@ -228,5 +229,7 @@ int main(int argc, char** argv)
|
|||||||
//showCPDS(network);
|
//showCPDS(network);
|
||||||
cout << "Score: " << network.score(Xd, y) << endl;
|
cout << "Score: " << network.score(Xd, y) << endl;
|
||||||
cout << "PyTorch version: " << TORCH_VERSION << endl;
|
cout << "PyTorch version: " << TORCH_VERSION << endl;
|
||||||
|
unsigned int nthreads = std::thread::hardware_concurrency();
|
||||||
|
cout << "Computer has " << nthreads << " cores." << endl;
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
@ -1,3 +1,5 @@
|
|||||||
|
#include <thread>
|
||||||
|
#include <mutex>
|
||||||
#include "Network.h"
|
#include "Network.h"
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {}
|
Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {}
|
||||||
@ -179,18 +181,31 @@ namespace bayesnet {
|
|||||||
}
|
}
|
||||||
vector<double> Network::exactInference(map<string, int>& evidence)
|
vector<double> Network::exactInference(map<string, int>& evidence)
|
||||||
{
|
{
|
||||||
vector<double> result;
|
vector<double> result(classNumStates, 0.0);
|
||||||
|
vector<thread> threads;
|
||||||
|
mutex mtx;
|
||||||
|
|
||||||
for (int i = 0; i < classNumStates; ++i) {
|
for (int i = 0; i < classNumStates; ++i) {
|
||||||
result.push_back(1.0);
|
threads.emplace_back([this, &result, &evidence, i, &mtx]() {
|
||||||
auto complete_evidence = map<string, int>(evidence);
|
auto completeEvidence = map<string, int>(evidence);
|
||||||
complete_evidence[getClassName()] = i;
|
completeEvidence[getClassName()] = i;
|
||||||
result[i] = computeFactor(complete_evidence);
|
double factor = computeFactor(completeEvidence);
|
||||||
|
|
||||||
|
lock_guard<mutex> lock(mtx);
|
||||||
|
result[i] = factor;
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (auto& thread : threads) {
|
||||||
|
thread.join();
|
||||||
|
}
|
||||||
|
|
||||||
// Normalize result
|
// Normalize result
|
||||||
auto sum = accumulate(result.begin(), result.end(), 0.0);
|
double sum = accumulate(result.begin(), result.end(), 0.0);
|
||||||
for (int i = 0; i < result.size(); ++i) {
|
for (double& value : result) {
|
||||||
result[i] /= sum;
|
value /= sum;
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user