Add threads without limit to network fit

This commit is contained in:
Ricardo Montañana Gómez 2023-09-04 21:24:11 +02:00
parent 05b670dfc0
commit 0b7beda78c
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE

View File

@ -133,9 +133,9 @@ namespace bayesnet {
{
// Set states to every Node in the network
for (int i = 0; i < features.size(); ++i) {
nodes[features[i]]->setNumStates(states.at(features[i]).size());
nodes.at(features.at(i))->setNumStates(states.at(features[i]).size());
}
classNumStates = nodes[className]->getNumStates();
classNumStates = nodes.at(className)->getNumStates();
}
// X comes in nxm, where n is the number of features and m the number of samples
void Network::fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const vector<string>& featureNames, const string& className, const map<string, vector<int>>& states)
@ -174,10 +174,16 @@ namespace bayesnet {
{
setStates(states);
laplaceSmoothing = 1.0 / samples.size(1); // To use in CPT computation
vector<thread> threads;
for (auto& node : nodes) {
node.second->computeCPT(samples, features, laplaceSmoothing, weights);
fitted = true;
threads.emplace_back([this, &node, &weights]() {
node.second->computeCPT(samples, features, laplaceSmoothing, weights);
});
}
for (auto& thread : threads) {
thread.join();
}
fitted = true;
}
torch::Tensor Network::predict_tensor(const torch::Tensor& samples, const bool proba)
{