Add threads to fit

This commit is contained in:
Ricardo Montañana Gómez 2023-07-06 12:40:47 +02:00
parent b6c21c21e2
commit 0b33c6c04a
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 63 additions and 19 deletions

View File

@ -221,7 +221,7 @@ int main(int argc, char** argv)
cout << endl; cout << endl;
cout << "Class name: " << className << endl; cout << "Class name: " << className << endl;
// Build Network // Build Network
auto network = bayesnet::Network(); auto network = bayesnet::Network(1.0);
build_network(network, network_name, maxes); build_network(network, network_name, maxes);
network.fit(Xd, y, features, className); network.fit(Xd, y, features, className);
cout << "Hello, Bayesian Networks!" << endl; cout << "Hello, Bayesian Networks!" << endl;

View File

@ -2,9 +2,10 @@
#include <mutex> #include <mutex>
#include "Network.h" #include "Network.h"
namespace bayesnet { namespace bayesnet {
// Network constructors, shown as a side-by-side diff: on each line the pre-commit
// text comes first and the post-commit text second.
//   - Default constructor: the new version additionally initializes maxThreads to 0.8.
//   - The old Network(int smoothing) overload is replaced by Network(float maxT),
//     which keeps laplaceSmoothing at 1 and stores maxT in maxThreads.
//   - New Network(float maxT, int smoothing) sets both values.
// NOTE(review): with Network(int) gone, a call such as Network(2) now converts the
// int to float and binds to Network(float), so the argument is taken as the thread
// fraction instead of the smoothing value. Confirm no caller still relies on the
// old single-int meaning.
// The last line below also carries the head of the old copy constructor on its
// left side; its body continues past these lines.
Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {} Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0), maxThreads(0.8) {}
Network::Network(int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {} Network::Network(float maxT) : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT) {}
Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), root(other.root), features(other.features), className(other.className), classNumStates(other.getClassNumStates()) Network::Network(float maxT, int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT) {}
Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), root(other.root), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), maxThreads(other.getmaxThreads())
{ {
for (auto& pair : other.nodes) { for (auto& pair : other.nodes) {
nodes[pair.first] = new Node(*pair.second); nodes[pair.first] = new Node(*pair.second);
@ -16,6 +17,10 @@ namespace bayesnet {
delete pair.second; delete pair.second;
} }
} }
// Accessor added by this commit: returns the maxThreads member by value.
// In fit() that member is multiplied by std::thread::hardware_concurrency() to
// derive the worker-thread limit, so it acts as a fraction of the available
// hardware threads rather than an absolute count.
float Network::getmaxThreads()
{
return maxThreads;
}
void Network::addNode(string name, int numStates) void Network::addNode(string name, int numStates)
{ {
if (nodes.find(name) != nodes.end()) { if (nodes.find(name) != nodes.end()) {
@ -93,24 +98,62 @@ namespace bayesnet {
{ {
return nodes; return nodes;
} }
// Network::fit, shown as a side-by-side diff: where a line holds two statements,
// the first is the pre-commit text and the second is the post-commit text; lines
// with a single statement exist on one side only.
//
// Old version: stores featureNames and className, copies each feature column and
// the labels into the `dataset` member, sets classNumStates to max(labels) + 1,
// then calls estimateParameters(), which runs computeCPT(dataset, laplaceSmoothing)
// on every node sequentially. estimateParameters() is the left-hand function that
// starts further down in this span; the commit removes it.
//
// New version: the first parameter is renamed to input_data so it no longer
// shadows the `dataset` member, the member is cleared before being rebuilt, and
// the per-node computeCPT calls are made from worker threads.
//
// Parameters (new side):
//   input_data   - indexed by feature position; input_data[i] is stored under featureNames[i].
//   labels       - class labels, stored under className.
//   featureNames - names used as dataset keys for the feature columns.
//   className    - dataset key for the labels.
// NOTE(review): the loop indexes input_data[i] for every i < featureNames.size();
// this assumes input_data has at least that many rows. Confirm against callers.
void Network::fit(const vector<vector<int>>& dataset, const vector<int>& labels, const vector<string>& featureNames, const string& className) void Network::fit(const vector<vector<int>>& input_data, const vector<int>& labels, const vector<string>& featureNames, const string& className)
{ {
features = featureNames; features = featureNames;
this->className = className; this->className = className;
dataset.clear();
// Build dataset // Build dataset
for (int i = 0; i < featureNames.size(); ++i) { for (int i = 0; i < featureNames.size(); ++i) {
this->dataset[featureNames[i]] = dataset[i]; dataset[featureNames[i]] = input_data[i];
} }
this->dataset[className] = labels; dataset[className] = labels;
// NOTE(review): max_element returns labels.end() for an empty vector, so the
// dereference below is only valid when labels is non-empty.
this->classNumStates = *max_element(labels.begin(), labels.end()) + 1; classNumStates = *max_element(labels.begin(), labels.end()) + 1;
// New side: thread limit = hardware_concurrency() * maxThreads, truncated to int.
// The "< 1" clamp that follows guarantees at least one worker, which also covers
// hardware_concurrency() reporting 0.
estimateParameters(); int maxThreadsRunning = static_cast<int>(std::thread::hardware_concurrency() * maxThreads);
} if (maxThreadsRunning < 1) {
maxThreadsRunning = 1;
}
cout << "Using " << maxThreadsRunning << " threads" << " maxThreads: " << maxThreads << endl;
vector<thread> threads;
mutex mtx;
condition_variable cv;
// Both counters below are shared between this function and the worker lambdas
// (captured by reference) and are modified while holding mtx.
int activeThreads = 0;
int nextNodeIndex = 0;
// New side: dispatcher loop. Each pass waits until fewer than maxThreadsRunning
// workers are active, then starts one more worker.
// NOTE(review): the while condition reads nextNodeIndex after the previous pass's
// `lock` has gone out of scope, while workers increment it under mtx. This looks
// like an unsynchronized read; confirm.
void Network::estimateParameters() while (nextNodeIndex < nodes.size()) {
{ unique_lock<mutex> lock(mtx);
// Old side: `dimensions` is not used in any of the old-side lines shown here.
auto dimensions = vector<int64_t>(); cv.wait(lock, [&activeThreads, &maxThreadsRunning]() { return activeThreads < maxThreadsRunning; });
for (auto [name, node] : nodes) {
node->computeCPT(dataset, laplaceSmoothing); if (nextNodeIndex >= nodes.size()) {
break; // No more work remaining
}
// Worker: repeatedly claims the next node index under mtx, releases the lock,
// and computes that node's CPT; it exits once every index has been claimed.
threads.emplace_back([this, &nextNodeIndex, &mtx, &cv, &activeThreads]() {
while (true) {
unique_lock<mutex> lock(mtx);
if (nextNodeIndex >= nodes.size()) {
break; // No more work remaining
}
// `pair` is a reference to the map element at position nextNodeIndex.
auto& pair = *std::next(nodes.begin(), nextNodeIndex);
++nextNodeIndex;
lock.unlock();
pair.second->computeCPT(dataset, laplaceSmoothing);
lock.lock();
// NOTE(review): pair refers to the element already stored in `nodes`, so this
// writes the same pointer back under the same key and appears to be a no-op.
nodes[pair.first] = pair.second;
lock.unlock();
}
// Worker is finished: decrement the active count under mtx and wake the dispatcher.
lock_guard<mutex> lock(mtx);
--activeThreads;
cv.notify_one();
});
// Still holding the dispatcher's `lock` here, so the increment is made under mtx.
++activeThreads;
}
// New side: join every spawned worker before returning.
for (auto& thread : threads) {
thread.join();
} }
} }
@ -193,7 +236,7 @@ namespace bayesnet {
lock_guard<mutex> lock(mtx); lock_guard<mutex> lock(mtx);
result[i] = factor; result[i] = factor;
}); });
} }
for (auto& thread : threads) { for (auto& thread : threads) {
@ -205,7 +248,6 @@ namespace bayesnet {
for (double& value : result) { for (double& value : result) {
value /= sum; value /= sum;
} }
return result; return result;
} }
} }

View File

@ -11,6 +11,7 @@ namespace bayesnet {
map<string, Node*> nodes; map<string, Node*> nodes;
map<string, vector<int>> dataset; map<string, vector<int>> dataset;
Node* root; Node* root;
float maxThreads;
int classNumStates; int classNumStates;
vector<string> features; vector<string> features;
string className; string className;
@ -21,9 +22,11 @@ namespace bayesnet {
double computeFactor(map<string, int>&); double computeFactor(map<string, int>&);
public: public:
Network(); Network();
Network(int); Network(float, int);
Network(float);
Network(Network&); Network(Network&);
~Network(); ~Network();
float getmaxThreads();
void addNode(string, int); void addNode(string, int);
void addEdge(const string, const string); void addEdge(const string, const string);
map<string, Node*>& getNodes(); map<string, Node*>& getNodes();
@ -31,7 +34,6 @@ namespace bayesnet {
int getClassNumStates(); int getClassNumStates();
string getClassName(); string getClassName();
void fit(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&); void fit(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&);
void estimateParameters();
void setRoot(string); void setRoot(string);
Node* getRoot(); Node* getRoot();
vector<int> predict(const vector<vector<int>>&); vector<int> predict(const vector<vector<int>>&);