diff --git a/bayesnet/network/Network.cc b/bayesnet/network/Network.cc index b399782..ceede5b 100644 --- a/bayesnet/network/Network.cc +++ b/bayesnet/network/Network.cc @@ -292,17 +292,18 @@ namespace bayesnet { std::vector predictions(tsamples[0].size(), 0); std::vector sample; std::vector threads; - std::mutex mtx; auto& semaphore = CountingSemaphore::getInstance(); - auto worker = [&](const std::vector& sample, const int row, std::vector& predictions) { - semaphore.acquire(); + auto worker = [&](const std::vector& sample, const int row, int& prediction) { + std::string threadName = "(V)PWorker-" + std::to_string(row); +#if defined(__linux__) + pthread_setname_np(pthread_self(), threadName.c_str()); +#else + pthread_setname_np(threadName.c_str()); +#endif auto classProbabilities = predict_sample(sample); auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end()); int predictedClass = distance(classProbabilities.begin(), maxElem); - { - std::lock_guard lock(mtx); - predictions[row] = predictedClass; - } + prediction = predictedClass; semaphore.release(); }; for (int row = 0; row < tsamples[0].size(); ++row) { @@ -310,7 +311,8 @@ namespace bayesnet { for (int col = 0; col < tsamples.size(); ++col) { sample.push_back(tsamples[col][row]); } - threads.emplace_back(worker, sample, row, std::ref(predictions)); + semaphore.acquire(); + threads.emplace_back(worker, sample, row, std::ref(predictions[row])); } for (auto& thread : threads) { thread.join(); @@ -329,14 +331,31 @@ namespace bayesnet { throw std::invalid_argument("(V) Sample size (" + std::to_string(tsamples.size()) + ") does not match the number of features (" + std::to_string(features.size() - 1) + ")"); } - std::vector> predictions; + std::vector> predictions(tsamples[0].size(), std::vector(classNumStates, 0.0)); std::vector sample; + std::vector threads; + auto& semaphore = CountingSemaphore::getInstance(); + auto worker = [&](const std::vector& sample, int row, std::vector& predictions) { + std::string threadName = "(V)PWorker-" + std::to_string(row); +#if defined(__linux__) + pthread_setname_np(pthread_self(), threadName.c_str()); +#else + pthread_setname_np(threadName.c_str()); +#endif + std::vector classProbabilities = predict_sample(sample); + predictions = classProbabilities; + semaphore.release(); + }; for (int row = 0; row < tsamples[0].size(); ++row) { sample.clear(); for (int col = 0; col < tsamples.size(); ++col) { sample.push_back(tsamples[col][row]); } - predictions.push_back(predict_sample(sample)); + semaphore.acquire(); + threads.emplace_back(worker, sample, row, std::ref(predictions[row])); + } + for (auto& thread : threads) { + thread.join(); } return predictions; } diff --git a/sample/CMakeLists.txt b/sample/CMakeLists.txt index d50030e..b36cfc7 100644 --- a/sample/CMakeLists.txt +++ b/sample/CMakeLists.txt @@ -8,7 +8,7 @@ find_package(Torch REQUIRED) find_library(BayesNet NAMES BayesNet.a libBayesNet.a REQUIRED) include_directories( - lib/Files + ../tests/lib/Files lib/mdlp lib/json/include /usr/local/include diff --git a/sample/sample.cc b/sample/sample.cc index 511230f..478ff85 100644 --- a/sample/sample.cc +++ b/sample/sample.cc @@ -60,9 +60,9 @@ int main(int argc, char* argv[]) auto clf = bayesnet::BoostAODE(false); // false for not using voting in predict std::cout << "Library version: " << clf.getVersion() << std::endl; tie(X, y, features, className, states) = loadDataset(file_name, true); - clf.fit(X, y, features, className, states); + clf.fit(X, y, features, className, states, bayesnet::Smoothing_t::LAPLACE); auto score = clf.score(X, y); - std::cout << "File: " << file_name << " score: " << score << std::endl; + std::cout << "File: " << file_name << " Model: BoostAODE score: " << score << std::endl; return 0; }