diff --git a/sample/main.cc b/sample/main.cc index 2881d2d..00aeac5 100644 --- a/sample/main.cc +++ b/sample/main.cc @@ -221,7 +221,7 @@ int main(int argc, char** argv) cout << endl; cout << "Class name: " << className << endl; // Build Network - auto network = bayesnet::Network(); + auto network = bayesnet::Network(1.0); build_network(network, network_name, maxes); network.fit(Xd, y, features, className); cout << "Hello, Bayesian Networks!" << endl; diff --git a/src/Network.cc b/src/Network.cc index 441e431..b29465d 100644 --- a/src/Network.cc +++ b/src/Network.cc @@ -2,9 +2,10 @@ #include #include "Network.h" namespace bayesnet { - Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector()), className(""), classNumStates(0) {} - Network::Network(int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector()), className(""), classNumStates(0) {} - Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), root(other.root), features(other.features), className(other.className), classNumStates(other.getClassNumStates()) + Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector()), className(""), classNumStates(0), maxThreads(0.8) {} + Network::Network(float maxT) : laplaceSmoothing(1), root(nullptr), features(vector()), className(""), classNumStates(0), maxThreads(maxT) {} + Network::Network(float maxT, int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector()), className(""), classNumStates(0), maxThreads(maxT) {} + Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), root(other.root), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), maxThreads(other.getmaxThreads()) { for (auto& pair : other.nodes) { nodes[pair.first] = new Node(*pair.second); @@ -16,6 +17,10 @@ namespace bayesnet { delete pair.second; } } + float Network::getmaxThreads() + { + return maxThreads; + } void Network::addNode(string name, int numStates) { if (nodes.find(name) != nodes.end()) { @@ -93,24 +98,62 @@ namespace bayesnet { { return nodes; } - void Network::fit(const vector>& dataset, const vector& labels, const vector& featureNames, const string& className) + void Network::fit(const vector>& input_data, const vector& labels, const vector& featureNames, const string& className) { features = featureNames; this->className = className; + dataset.clear(); + // Build dataset for (int i = 0; i < featureNames.size(); ++i) { - this->dataset[featureNames[i]] = dataset[i]; + dataset[featureNames[i]] = input_data[i]; } - this->dataset[className] = labels; - this->classNumStates = *max_element(labels.begin(), labels.end()) + 1; - estimateParameters(); - } + dataset[className] = labels; + classNumStates = *max_element(labels.begin(), labels.end()) + 1; + int maxThreadsRunning = static_cast(std::thread::hardware_concurrency() * maxThreads); + if (maxThreadsRunning < 1) { + maxThreadsRunning = 1; + } + cout << "Using " << maxThreadsRunning << " threads" << " maxThreads: " << maxThreads << endl; + vector threads; + mutex mtx; + condition_variable cv; + int activeThreads = 0; + int nextNodeIndex = 0; - void Network::estimateParameters() - { - auto dimensions = vector(); - for (auto [name, node] : nodes) { - node->computeCPT(dataset, laplaceSmoothing); + while (nextNodeIndex < nodes.size()) { + unique_lock lock(mtx); + cv.wait(lock, [&activeThreads, &maxThreadsRunning]() { return activeThreads < maxThreadsRunning; }); + + if (nextNodeIndex >= nodes.size()) { + break; // No more work remaining + } + + threads.emplace_back([this, &nextNodeIndex, &mtx, &cv, &activeThreads]() { + while (true) { + unique_lock lock(mtx); + if (nextNodeIndex >= nodes.size()) { + break; // No more work remaining + } + auto& pair = *std::next(nodes.begin(), nextNodeIndex); + ++nextNodeIndex; + lock.unlock(); + + pair.second->computeCPT(dataset, laplaceSmoothing); + + lock.lock(); + nodes[pair.first] = pair.second; + lock.unlock(); + } + lock_guard lock(mtx); + --activeThreads; + cv.notify_one(); + }); + + ++activeThreads; + } + for (auto& thread : threads) { + thread.join(); } } @@ -193,7 +236,7 @@ namespace bayesnet { lock_guard lock(mtx); result[i] = factor; - }); + }); } for (auto& thread : threads) { @@ -205,7 +248,6 @@ namespace bayesnet { for (double& value : result) { value /= sum; } - return result; } } diff --git a/src/Network.h b/src/Network.h index 2c01e12..9ba531c 100644 --- a/src/Network.h +++ b/src/Network.h @@ -11,6 +11,7 @@ namespace bayesnet { map nodes; map> dataset; Node* root; + float maxThreads; int classNumStates; vector features; string className; @@ -21,9 +22,11 @@ namespace bayesnet { double computeFactor(map&); public: Network(); - Network(int); + Network(float, int); + Network(float); Network(Network&); ~Network(); + float getmaxThreads(); void addNode(string, int); void addEdge(const string, const string); map& getNodes(); @@ -31,7 +34,6 @@ namespace bayesnet { int getClassNumStates(); string getClassName(); void fit(const vector>&, const vector&, const vector&, const string&); - void estimateParameters(); void setRoot(string); Node* getRoot(); vector predict(const vector>&);