diff --git a/src/BayesNet/Network.cc b/src/BayesNet/Network.cc index da8898c..b9fc659 100644 --- a/src/BayesNet/Network.cc +++ b/src/BayesNet/Network.cc @@ -133,9 +133,9 @@ namespace bayesnet { { // Set states to every Node in the network for (int i = 0; i < features.size(); ++i) { - nodes[features[i]]->setNumStates(states.at(features[i]).size()); + nodes.at(features.at(i))->setNumStates(states.at(features[i]).size()); } - classNumStates = nodes[className]->getNumStates(); + classNumStates = nodes.at(className)->getNumStates(); } // X comes in nxm, where n is the number of features and m the number of samples void Network::fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const vector& featureNames, const string& className, const map>& states) @@ -174,10 +174,16 @@ namespace bayesnet { { setStates(states); laplaceSmoothing = 1.0 / samples.size(1); // To use in CPT computation + vector threads; for (auto& node : nodes) { - node.second->computeCPT(samples, features, laplaceSmoothing, weights); - fitted = true; + threads.emplace_back([this, &node, &weights]() { + node.second->computeCPT(samples, features, laplaceSmoothing, weights); + }); } + for (auto& thread : threads) { + thread.join(); + } + fitted = true; } torch::Tensor Network::predict_tensor(const torch::Tensor& samples, const bool proba) {