Add threads to fit
This commit is contained in:
parent
b6c21c21e2
commit
0b33c6c04a
@ -221,7 +221,7 @@ int main(int argc, char** argv)
|
||||
cout << endl;
|
||||
cout << "Class name: " << className << endl;
|
||||
// Build Network
|
||||
auto network = bayesnet::Network();
|
||||
auto network = bayesnet::Network(1.0);
|
||||
build_network(network, network_name, maxes);
|
||||
network.fit(Xd, y, features, className);
|
||||
cout << "Hello, Bayesian Networks!" << endl;
|
||||
|
@ -2,9 +2,10 @@
|
||||
#include <mutex>
|
||||
#include "Network.h"
|
||||
namespace bayesnet {
|
||||
Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {}
|
||||
Network::Network(int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {}
|
||||
Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), root(other.root), features(other.features), className(other.className), classNumStates(other.getClassNumStates())
|
||||
Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0), maxThreads(0.8) {}
|
||||
Network::Network(float maxT) : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT) {}
|
||||
Network::Network(float maxT, int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT) {}
|
||||
Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), root(other.root), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), maxThreads(other.getmaxThreads())
|
||||
{
|
||||
for (auto& pair : other.nodes) {
|
||||
nodes[pair.first] = new Node(*pair.second);
|
||||
@ -16,6 +17,10 @@ namespace bayesnet {
|
||||
delete pair.second;
|
||||
}
|
||||
}
|
||||
float Network::getmaxThreads()
|
||||
{
|
||||
return maxThreads;
|
||||
}
|
||||
void Network::addNode(string name, int numStates)
|
||||
{
|
||||
if (nodes.find(name) != nodes.end()) {
|
||||
@ -93,24 +98,62 @@ namespace bayesnet {
|
||||
{
|
||||
return nodes;
|
||||
}
|
||||
void Network::fit(const vector<vector<int>>& dataset, const vector<int>& labels, const vector<string>& featureNames, const string& className)
|
||||
void Network::fit(const vector<vector<int>>& input_data, const vector<int>& labels, const vector<string>& featureNames, const string& className)
|
||||
{
|
||||
features = featureNames;
|
||||
this->className = className;
|
||||
dataset.clear();
|
||||
|
||||
// Build dataset
|
||||
for (int i = 0; i < featureNames.size(); ++i) {
|
||||
this->dataset[featureNames[i]] = dataset[i];
|
||||
dataset[featureNames[i]] = input_data[i];
|
||||
}
|
||||
this->dataset[className] = labels;
|
||||
this->classNumStates = *max_element(labels.begin(), labels.end()) + 1;
|
||||
estimateParameters();
|
||||
dataset[className] = labels;
|
||||
classNumStates = *max_element(labels.begin(), labels.end()) + 1;
|
||||
int maxThreadsRunning = static_cast<int>(std::thread::hardware_concurrency() * maxThreads);
|
||||
if (maxThreadsRunning < 1) {
|
||||
maxThreadsRunning = 1;
|
||||
}
|
||||
cout << "Using " << maxThreadsRunning << " threads" << " maxThreads: " << maxThreads << endl;
|
||||
vector<thread> threads;
|
||||
mutex mtx;
|
||||
condition_variable cv;
|
||||
int activeThreads = 0;
|
||||
int nextNodeIndex = 0;
|
||||
|
||||
while (nextNodeIndex < nodes.size()) {
|
||||
unique_lock<mutex> lock(mtx);
|
||||
cv.wait(lock, [&activeThreads, &maxThreadsRunning]() { return activeThreads < maxThreadsRunning; });
|
||||
|
||||
if (nextNodeIndex >= nodes.size()) {
|
||||
break; // No more work remaining
|
||||
}
|
||||
|
||||
void Network::estimateParameters()
|
||||
{
|
||||
auto dimensions = vector<int64_t>();
|
||||
for (auto [name, node] : nodes) {
|
||||
node->computeCPT(dataset, laplaceSmoothing);
|
||||
threads.emplace_back([this, &nextNodeIndex, &mtx, &cv, &activeThreads]() {
|
||||
while (true) {
|
||||
unique_lock<mutex> lock(mtx);
|
||||
if (nextNodeIndex >= nodes.size()) {
|
||||
break; // No more work remaining
|
||||
}
|
||||
auto& pair = *std::next(nodes.begin(), nextNodeIndex);
|
||||
++nextNodeIndex;
|
||||
lock.unlock();
|
||||
|
||||
pair.second->computeCPT(dataset, laplaceSmoothing);
|
||||
|
||||
lock.lock();
|
||||
nodes[pair.first] = pair.second;
|
||||
lock.unlock();
|
||||
}
|
||||
lock_guard<mutex> lock(mtx);
|
||||
--activeThreads;
|
||||
cv.notify_one();
|
||||
});
|
||||
|
||||
++activeThreads;
|
||||
}
|
||||
for (auto& thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
@ -205,7 +248,6 @@ namespace bayesnet {
|
||||
for (double& value : result) {
|
||||
value /= sum;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
@ -11,6 +11,7 @@ namespace bayesnet {
|
||||
map<string, Node*> nodes;
|
||||
map<string, vector<int>> dataset;
|
||||
Node* root;
|
||||
float maxThreads;
|
||||
int classNumStates;
|
||||
vector<string> features;
|
||||
string className;
|
||||
@ -21,9 +22,11 @@ namespace bayesnet {
|
||||
double computeFactor(map<string, int>&);
|
||||
public:
|
||||
Network();
|
||||
Network(int);
|
||||
Network(float, int);
|
||||
Network(float);
|
||||
Network(Network&);
|
||||
~Network();
|
||||
float getmaxThreads();
|
||||
void addNode(string, int);
|
||||
void addEdge(const string, const string);
|
||||
map<string, Node*>& getNodes();
|
||||
@ -31,7 +34,6 @@ namespace bayesnet {
|
||||
int getClassNumStates();
|
||||
string getClassName();
|
||||
void fit(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&);
|
||||
void estimateParameters();
|
||||
void setRoot(string);
|
||||
Node* getRoot();
|
||||
vector<int> predict(const vector<vector<int>>&);
|
||||
|
Loading…
Reference in New Issue
Block a user