Remove threads

This commit is contained in:
Ricardo Montañana Gómez 2023-08-31 20:30:28 +02:00
parent 7c3e315ae7
commit 7806f961e2
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
6 changed files with 42 additions and 61 deletions

5
.vscode/launch.json vendored
View File

@@ -25,12 +25,13 @@
"program": "${workspaceFolder}/build/src/Platform/main", "program": "${workspaceFolder}/build/src/Platform/main",
"args": [ "args": [
"-m", "-m",
"AODELd", "AODE",
"-p", "-p",
"/Users/rmontanana/Code/discretizbench/datasets", "/Users/rmontanana/Code/discretizbench/datasets",
"--stratified", "--stratified",
"-d", "-d",
"wine" "letter",
"--discretize"
// "--hyperparameters", // "--hyperparameters",
// "{\"repeatSparent\": true, \"maxModels\": 12}" // "{\"repeatSparent\": true, \"maxModels\": 12}"
], ],

View File

@@ -34,18 +34,22 @@ namespace bayesnet {
throw logic_error("Ensemble has not been fitted"); throw logic_error("Ensemble has not been fitted");
} }
Tensor y_pred = torch::zeros({ X.size(1), n_models }, kInt32); Tensor y_pred = torch::zeros({ X.size(1), n_models }, kInt32);
//Create a threadpool // //Create a threadpool
auto threads{ vector<thread>() }; // auto threads{ vector<thread>() };
mutex mtx; // mutex mtx;
// for (auto i = 0; i < n_models; ++i) {
// threads.push_back(thread([&, i]() {
// auto ypredict = models[i]->predict(X);
// lock_guard<mutex> lock(mtx);
// y_pred.index_put_({ "...", i }, ypredict);
// }));
// Hacer voting aquí ? ? ?
// }
// for (auto& thread : threads) {
// thread.join();
// }
for (auto i = 0; i < n_models; ++i) { for (auto i = 0; i < n_models; ++i) {
threads.push_back(thread([&, i]() { y_pred.index_put_({ "...", i }, models[i]->predict(X));
auto ypredict = models[i]->predict(X);
lock_guard<mutex> lock(mtx);
y_pred.index_put_({ "...", i }, ypredict);
}));
}
for (auto& thread : threads) {
thread.join();
} }
return torch::tensor(voting(y_pred)); return torch::tensor(voting(y_pred));
} }

View File

@@ -174,42 +174,10 @@ namespace bayesnet {
{ {
setStates(states); setStates(states);
laplaceSmoothing = 1.0 / samples.size(1); // To use in CPT computation laplaceSmoothing = 1.0 / samples.size(1); // To use in CPT computation
int maxThreadsRunning = static_cast<int>(std::thread::hardware_concurrency() * maxThreads); for (auto& node : nodes) {
if (maxThreadsRunning < 1) { node.second->computeCPT(samples, features, laplaceSmoothing, weights);
maxThreadsRunning = 1; fitted = true;
} }
vector<thread> threads;
mutex mtx;
condition_variable cv;
int activeThreads = 0;
int nextNodeIndex = 0;
while (nextNodeIndex < nodes.size()) {
unique_lock<mutex> lock(mtx);
cv.wait(lock, [&activeThreads, &maxThreadsRunning]() { return activeThreads < maxThreadsRunning; });
threads.emplace_back([this, &nextNodeIndex, &mtx, &cv, &activeThreads, &weights]() {
while (true) {
unique_lock<mutex> lock(mtx);
if (nextNodeIndex >= nodes.size()) {
break; // No more work remaining
}
auto& pair = *std::next(nodes.begin(), nextNodeIndex);
++nextNodeIndex;
lock.unlock();
pair.second->computeCPT(samples, features, laplaceSmoothing, weights);
lock.lock();
nodes[pair.first] = std::move(pair.second);
lock.unlock();
}
lock_guard<mutex> lock(mtx);
--activeThreads;
cv.notify_one();
});
++activeThreads;
}
for (auto& thread : threads) {
thread.join();
}
fitted = true;
} }
torch::Tensor Network::predict_tensor(const torch::Tensor& samples, const bool proba) torch::Tensor Network::predict_tensor(const torch::Tensor& samples, const bool proba)
{ {
@@ -331,19 +299,25 @@ namespace bayesnet {
vector<double> Network::exactInference(map<string, int>& evidence) vector<double> Network::exactInference(map<string, int>& evidence)
{ {
vector<double> result(classNumStates, 0.0); vector<double> result(classNumStates, 0.0);
vector<thread> threads; // vector<thread> threads;
mutex mtx; // mutex mtx;
// for (int i = 0; i < classNumStates; ++i) {
// threads.emplace_back([this, &result, &evidence, i, &mtx]() {
// auto completeEvidence = map<string, int>(evidence);
// completeEvidence[getClassName()] = i;
// double factor = computeFactor(completeEvidence);
// lock_guard<mutex> lock(mtx);
// result[i] = factor;
// });
// }
// for (auto& thread : threads) {
// thread.join();
// }
for (int i = 0; i < classNumStates; ++i) { for (int i = 0; i < classNumStates; ++i) {
threads.emplace_back([this, &result, &evidence, i, &mtx]() { auto completeEvidence = map<string, int>(evidence);
auto completeEvidence = map<string, int>(evidence); completeEvidence[getClassName()] = i;
completeEvidence[getClassName()] = i; double factor = computeFactor(completeEvidence);
double factor = computeFactor(completeEvidence); result[i] = factor;
lock_guard<mutex> lock(mtx);
result[i] = factor;
});
}
for (auto& thread : threads) {
thread.join();
} }
// Normalize result // Normalize result
double sum = accumulate(result.begin(), result.end(), 0.0); double sum = accumulate(result.begin(), result.end(), 0.0);

View File

@@ -27,6 +27,7 @@ namespace bayesnet {
Network(); Network();
explicit Network(float); explicit Network(float);
explicit Network(Network&); explicit Network(Network&);
~Network() = default;
torch::Tensor& getSamples(); torch::Tensor& getSamples();
float getmaxThreads(); float getmaxThreads();
void addNode(const string&); void addNode(const string&);

View File

@@ -179,6 +179,7 @@ namespace platform {
result.addTimeTrain(train_time[item].item<double>()); result.addTimeTrain(train_time[item].item<double>());
result.addTimeTest(test_time[item].item<double>()); result.addTimeTest(test_time[item].item<double>());
item++; item++;
clf.reset();
} }
cout << "end. " << flush; cout << "end. " << flush;
} }

View File

@@ -26,7 +26,7 @@ namespace platform {
instance = it->second(); instance = it->second();
// wrap instance in a shared ptr and return // wrap instance in a shared ptr and return
if (instance != nullptr) if (instance != nullptr)
return shared_ptr<bayesnet::BaseClassifier>(instance); return unique_ptr<bayesnet::BaseClassifier>(instance);
else else
return nullptr; return nullptr;
} }