Remove threads

This commit is contained in:
Ricardo Montañana Gómez 2023-08-31 20:30:28 +02:00
parent 7c3e315ae7
commit 7806f961e2
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
6 changed files with 42 additions and 61 deletions

5
.vscode/launch.json vendored
View File

@ -25,12 +25,13 @@
"program": "${workspaceFolder}/build/src/Platform/main",
"args": [
"-m",
"AODELd",
"AODE",
"-p",
"/Users/rmontanana/Code/discretizbench/datasets",
"--stratified",
"-d",
"wine"
"letter",
"--discretize"
// "--hyperparameters",
// "{\"repeatSparent\": true, \"maxModels\": 12}"
],

View File

@ -34,18 +34,22 @@ namespace bayesnet {
throw logic_error("Ensemble has not been fitted");
}
Tensor y_pred = torch::zeros({ X.size(1), n_models }, kInt32);
//Create a threadpool
auto threads{ vector<thread>() };
mutex mtx;
// //Create a threadpool
// auto threads{ vector<thread>() };
// mutex mtx;
// for (auto i = 0; i < n_models; ++i) {
// threads.push_back(thread([&, i]() {
// auto ypredict = models[i]->predict(X);
// lock_guard<mutex> lock(mtx);
// y_pred.index_put_({ "...", i }, ypredict);
// }));
// Hacer voting aquí ? ? ?
// }
// for (auto& thread : threads) {
// thread.join();
// }
for (auto i = 0; i < n_models; ++i) {
threads.push_back(thread([&, i]() {
auto ypredict = models[i]->predict(X);
lock_guard<mutex> lock(mtx);
y_pred.index_put_({ "...", i }, ypredict);
}));
}
for (auto& thread : threads) {
thread.join();
y_pred.index_put_({ "...", i }, models[i]->predict(X));
}
return torch::tensor(voting(y_pred));
}

View File

@ -174,42 +174,10 @@ namespace bayesnet {
{
setStates(states);
laplaceSmoothing = 1.0 / samples.size(1); // To use in CPT computation
int maxThreadsRunning = static_cast<int>(std::thread::hardware_concurrency() * maxThreads);
if (maxThreadsRunning < 1) {
maxThreadsRunning = 1;
for (auto& node : nodes) {
node.second->computeCPT(samples, features, laplaceSmoothing, weights);
fitted = true;
}
vector<thread> threads;
mutex mtx;
condition_variable cv;
int activeThreads = 0;
int nextNodeIndex = 0;
while (nextNodeIndex < nodes.size()) {
unique_lock<mutex> lock(mtx);
cv.wait(lock, [&activeThreads, &maxThreadsRunning]() { return activeThreads < maxThreadsRunning; });
threads.emplace_back([this, &nextNodeIndex, &mtx, &cv, &activeThreads, &weights]() {
while (true) {
unique_lock<mutex> lock(mtx);
if (nextNodeIndex >= nodes.size()) {
break; // No more work remaining
}
auto& pair = *std::next(nodes.begin(), nextNodeIndex);
++nextNodeIndex;
lock.unlock();
pair.second->computeCPT(samples, features, laplaceSmoothing, weights);
lock.lock();
nodes[pair.first] = std::move(pair.second);
lock.unlock();
}
lock_guard<mutex> lock(mtx);
--activeThreads;
cv.notify_one();
});
++activeThreads;
}
for (auto& thread : threads) {
thread.join();
}
fitted = true;
}
torch::Tensor Network::predict_tensor(const torch::Tensor& samples, const bool proba)
{
@ -331,19 +299,25 @@ namespace bayesnet {
vector<double> Network::exactInference(map<string, int>& evidence)
{
vector<double> result(classNumStates, 0.0);
vector<thread> threads;
mutex mtx;
// vector<thread> threads;
// mutex mtx;
// for (int i = 0; i < classNumStates; ++i) {
// threads.emplace_back([this, &result, &evidence, i, &mtx]() {
// auto completeEvidence = map<string, int>(evidence);
// completeEvidence[getClassName()] = i;
// double factor = computeFactor(completeEvidence);
// lock_guard<mutex> lock(mtx);
// result[i] = factor;
// });
// }
// for (auto& thread : threads) {
// thread.join();
// }
for (int i = 0; i < classNumStates; ++i) {
threads.emplace_back([this, &result, &evidence, i, &mtx]() {
auto completeEvidence = map<string, int>(evidence);
completeEvidence[getClassName()] = i;
double factor = computeFactor(completeEvidence);
lock_guard<mutex> lock(mtx);
result[i] = factor;
});
}
for (auto& thread : threads) {
thread.join();
auto completeEvidence = map<string, int>(evidence);
completeEvidence[getClassName()] = i;
double factor = computeFactor(completeEvidence);
result[i] = factor;
}
// Normalize result
double sum = accumulate(result.begin(), result.end(), 0.0);

View File

@ -27,6 +27,7 @@ namespace bayesnet {
Network();
explicit Network(float);
explicit Network(Network&);
~Network() = default;
torch::Tensor& getSamples();
float getmaxThreads();
void addNode(const string&);

View File

@ -179,6 +179,7 @@ namespace platform {
result.addTimeTrain(train_time[item].item<double>());
result.addTimeTest(test_time[item].item<double>());
item++;
clf.reset();
}
cout << "end. " << flush;
}

View File

@ -26,7 +26,7 @@ namespace platform {
instance = it->second();
// wrap instance in a shared ptr and return
if (instance != nullptr)
return shared_ptr<bayesnet::BaseClassifier>(instance);
return unique_ptr<bayesnet::BaseClassifier>(instance);
else
return nullptr;
}