Add thread control to vectors predict
This commit is contained in:
parent
59c1cf5b3b
commit
9a14133be5
@ -292,17 +292,18 @@ namespace bayesnet {
|
||||
std::vector<int> predictions(tsamples[0].size(), 0);
|
||||
std::vector<int> sample;
|
||||
std::vector<std::thread> threads;
|
||||
std::mutex mtx;
|
||||
auto& semaphore = CountingSemaphore::getInstance();
|
||||
auto worker = [&](const std::vector<int>& sample, const int row, std::vector<int>& predictions) {
|
||||
semaphore.acquire();
|
||||
auto worker = [&](const std::vector<int>& sample, const int row, int& prediction) {
|
||||
std::string threadName = "(V)PWorker-" + std::to_string(row);
|
||||
#if defined(__linux__)
|
||||
pthread_setname_np(pthread_self(), threadName.c_str());
|
||||
#else
|
||||
pthread_setname_np(threadName.c_str());
|
||||
#endif
|
||||
auto classProbabilities = predict_sample(sample);
|
||||
auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end());
|
||||
int predictedClass = distance(classProbabilities.begin(), maxElem);
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mtx);
|
||||
predictions[row] = predictedClass;
|
||||
}
|
||||
prediction = predictedClass;
|
||||
semaphore.release();
|
||||
};
|
||||
for (int row = 0; row < tsamples[0].size(); ++row) {
|
||||
@ -310,7 +311,8 @@ namespace bayesnet {
|
||||
for (int col = 0; col < tsamples.size(); ++col) {
|
||||
sample.push_back(tsamples[col][row]);
|
||||
}
|
||||
threads.emplace_back(worker, sample, row, std::ref(predictions));
|
||||
semaphore.acquire();
|
||||
threads.emplace_back(worker, sample, row, std::ref(predictions[row]));
|
||||
}
|
||||
for (auto& thread : threads) {
|
||||
thread.join();
|
||||
@ -329,14 +331,31 @@ namespace bayesnet {
|
||||
throw std::invalid_argument("(V) Sample size (" + std::to_string(tsamples.size()) +
|
||||
") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
|
||||
}
|
||||
std::vector<std::vector<double>> predictions;
|
||||
std::vector<std::vector<double>> predictions(tsamples[0].size(), std::vector<double>(classNumStates, 0.0));
|
||||
std::vector<int> sample;
|
||||
std::vector<std::thread> threads;
|
||||
auto& semaphore = CountingSemaphore::getInstance();
|
||||
auto worker = [&](const std::vector<int>& sample, int row, std::vector<double>& predictions) {
|
||||
std::string threadName = "(V)PWorker-" + std::to_string(row);
|
||||
#if defined(__linux__)
|
||||
pthread_setname_np(pthread_self(), threadName.c_str());
|
||||
#else
|
||||
pthread_setname_np(threadName.c_str());
|
||||
#endif
|
||||
std::vector<double> classProbabilities = predict_sample(sample);
|
||||
predictions = classProbabilities;
|
||||
semaphore.release();
|
||||
};
|
||||
for (int row = 0; row < tsamples[0].size(); ++row) {
|
||||
sample.clear();
|
||||
for (int col = 0; col < tsamples.size(); ++col) {
|
||||
sample.push_back(tsamples[col][row]);
|
||||
}
|
||||
predictions.push_back(predict_sample(sample));
|
||||
semaphore.acquire();
|
||||
threads.emplace_back(worker, sample, row, std::ref(predictions[row]));
|
||||
}
|
||||
for (auto& thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
return predictions;
|
||||
}
|
||||
|
@ -8,7 +8,7 @@ find_package(Torch REQUIRED)
|
||||
find_library(BayesNet NAMES BayesNet.a libBayesNet.a REQUIRED)
|
||||
|
||||
include_directories(
|
||||
lib/Files
|
||||
../tests/lib/Files
|
||||
lib/mdlp
|
||||
lib/json/include
|
||||
/usr/local/include
|
||||
|
@ -60,9 +60,9 @@ int main(int argc, char* argv[])
|
||||
auto clf = bayesnet::BoostAODE(false); // false for not using voting in predict
|
||||
std::cout << "Library version: " << clf.getVersion() << std::endl;
|
||||
tie(X, y, features, className, states) = loadDataset(file_name, true);
|
||||
clf.fit(X, y, features, className, states);
|
||||
clf.fit(X, y, features, className, states, bayesnet::Smoothing_t::LAPLACE);
|
||||
auto score = clf.score(X, y);
|
||||
std::cout << "File: " << file_name << " score: " << score << std::endl;
|
||||
std::cout << "File: " << file_name << " Model: BoostAODE score: " << score << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user