Add threads to fit

This commit is contained in:
Ricardo Montañana Gómez 2023-07-06 12:40:47 +02:00
parent b6c21c21e2
commit 0b33c6c04a
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 63 additions and 19 deletions

View File

@ -221,7 +221,7 @@ int main(int argc, char** argv)
cout << endl;
cout << "Class name: " << className << endl;
// Build Network
auto network = bayesnet::Network();
auto network = bayesnet::Network(1.0);
build_network(network, network_name, maxes);
network.fit(Xd, y, features, className);
cout << "Hello, Bayesian Networks!" << endl;

View File

@ -2,9 +2,10 @@
#include <mutex>
#include "Network.h"
namespace bayesnet {
Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {}
Network::Network(int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {}
Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), root(other.root), features(other.features), className(other.className), classNumStates(other.getClassNumStates())
Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0), maxThreads(0.8) {}
Network::Network(float maxT) : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT) {}
Network::Network(float maxT, int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT) {}
Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), root(other.root), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), maxThreads(other.getmaxThreads())
{
for (auto& pair : other.nodes) {
nodes[pair.first] = new Node(*pair.second);
@ -16,6 +17,10 @@ namespace bayesnet {
delete pair.second;
}
}
float Network::getmaxThreads()
{
return maxThreads;
}
void Network::addNode(string name, int numStates)
{
if (nodes.find(name) != nodes.end()) {
@ -93,24 +98,62 @@ namespace bayesnet {
{
return nodes;
}
void Network::fit(const vector<vector<int>>& dataset, const vector<int>& labels, const vector<string>& featureNames, const string& className)
void Network::fit(const vector<vector<int>>& input_data, const vector<int>& labels, const vector<string>& featureNames, const string& className)
{
features = featureNames;
this->className = className;
dataset.clear();
// Build dataset
for (int i = 0; i < featureNames.size(); ++i) {
this->dataset[featureNames[i]] = dataset[i];
dataset[featureNames[i]] = input_data[i];
}
this->dataset[className] = labels;
this->classNumStates = *max_element(labels.begin(), labels.end()) + 1;
estimateParameters();
dataset[className] = labels;
classNumStates = *max_element(labels.begin(), labels.end()) + 1;
int maxThreadsRunning = static_cast<int>(std::thread::hardware_concurrency() * maxThreads);
if (maxThreadsRunning < 1) {
maxThreadsRunning = 1;
}
cout << "Using " << maxThreadsRunning << " threads" << " maxThreads: " << maxThreads << endl;
vector<thread> threads;
mutex mtx;
condition_variable cv;
int activeThreads = 0;
int nextNodeIndex = 0;
while (nextNodeIndex < nodes.size()) {
unique_lock<mutex> lock(mtx);
cv.wait(lock, [&activeThreads, &maxThreadsRunning]() { return activeThreads < maxThreadsRunning; });
if (nextNodeIndex >= nodes.size()) {
break; // No more work remaining
}
void Network::estimateParameters()
{
auto dimensions = vector<int64_t>();
for (auto [name, node] : nodes) {
node->computeCPT(dataset, laplaceSmoothing);
threads.emplace_back([this, &nextNodeIndex, &mtx, &cv, &activeThreads]() {
while (true) {
unique_lock<mutex> lock(mtx);
if (nextNodeIndex >= nodes.size()) {
break; // No more work remaining
}
auto& pair = *std::next(nodes.begin(), nextNodeIndex);
++nextNodeIndex;
lock.unlock();
pair.second->computeCPT(dataset, laplaceSmoothing);
lock.lock();
nodes[pair.first] = pair.second;
lock.unlock();
}
lock_guard<mutex> lock(mtx);
--activeThreads;
cv.notify_one();
});
++activeThreads;
}
for (auto& thread : threads) {
thread.join();
}
}
@ -205,7 +248,6 @@ namespace bayesnet {
for (double& value : result) {
value /= sum;
}
return result;
}
}

View File

@ -11,6 +11,7 @@ namespace bayesnet {
map<string, Node*> nodes;
map<string, vector<int>> dataset;
Node* root;
float maxThreads;
int classNumStates;
vector<string> features;
string className;
@ -21,9 +22,11 @@ namespace bayesnet {
double computeFactor(map<string, int>&);
public:
Network();
Network(int);
Network(float, int);
Network(float);
Network(Network&);
~Network();
float getmaxThreads();
void addNode(string, int);
void addEdge(const string, const string);
map<string, Node*>& getNodes();
@ -31,7 +34,6 @@ namespace bayesnet {
int getClassNumStates();
string getClassName();
void fit(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&);
void estimateParameters();
void setRoot(string);
Node* getRoot();
vector<int> predict(const vector<vector<int>>&);