Add threads to fit
This commit is contained in:
parent
b6c21c21e2
commit
0b33c6c04a
@ -221,7 +221,7 @@ int main(int argc, char** argv)
|
|||||||
cout << endl;
|
cout << endl;
|
||||||
cout << "Class name: " << className << endl;
|
cout << "Class name: " << className << endl;
|
||||||
// Build Network
|
// Build Network
|
||||||
auto network = bayesnet::Network();
|
auto network = bayesnet::Network(1.0);
|
||||||
build_network(network, network_name, maxes);
|
build_network(network, network_name, maxes);
|
||||||
network.fit(Xd, y, features, className);
|
network.fit(Xd, y, features, className);
|
||||||
cout << "Hello, Bayesian Networks!" << endl;
|
cout << "Hello, Bayesian Networks!" << endl;
|
||||||
|
@ -2,9 +2,10 @@
|
|||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include "Network.h"
|
#include "Network.h"
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {}
|
Network::Network() : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0), maxThreads(0.8) {}
|
||||||
Network::Network(int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector<string>()), className(""), classNumStates(0) {}
|
Network::Network(float maxT) : laplaceSmoothing(1), root(nullptr), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT) {}
|
||||||
Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), root(other.root), features(other.features), className(other.className), classNumStates(other.getClassNumStates())
|
Network::Network(float maxT, int smoothing) : laplaceSmoothing(smoothing), root(nullptr), features(vector<string>()), className(""), classNumStates(0), maxThreads(maxT) {}
|
||||||
|
Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), root(other.root), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), maxThreads(other.getmaxThreads())
|
||||||
{
|
{
|
||||||
for (auto& pair : other.nodes) {
|
for (auto& pair : other.nodes) {
|
||||||
nodes[pair.first] = new Node(*pair.second);
|
nodes[pair.first] = new Node(*pair.second);
|
||||||
@ -16,6 +17,10 @@ namespace bayesnet {
|
|||||||
delete pair.second;
|
delete pair.second;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
float Network::getmaxThreads()
|
||||||
|
{
|
||||||
|
return maxThreads;
|
||||||
|
}
|
||||||
void Network::addNode(string name, int numStates)
|
void Network::addNode(string name, int numStates)
|
||||||
{
|
{
|
||||||
if (nodes.find(name) != nodes.end()) {
|
if (nodes.find(name) != nodes.end()) {
|
||||||
@ -93,24 +98,62 @@ namespace bayesnet {
|
|||||||
{
|
{
|
||||||
return nodes;
|
return nodes;
|
||||||
}
|
}
|
||||||
void Network::fit(const vector<vector<int>>& dataset, const vector<int>& labels, const vector<string>& featureNames, const string& className)
|
void Network::fit(const vector<vector<int>>& input_data, const vector<int>& labels, const vector<string>& featureNames, const string& className)
|
||||||
{
|
{
|
||||||
features = featureNames;
|
features = featureNames;
|
||||||
this->className = className;
|
this->className = className;
|
||||||
|
dataset.clear();
|
||||||
|
|
||||||
// Build dataset
|
// Build dataset
|
||||||
for (int i = 0; i < featureNames.size(); ++i) {
|
for (int i = 0; i < featureNames.size(); ++i) {
|
||||||
this->dataset[featureNames[i]] = dataset[i];
|
dataset[featureNames[i]] = input_data[i];
|
||||||
}
|
}
|
||||||
this->dataset[className] = labels;
|
dataset[className] = labels;
|
||||||
this->classNumStates = *max_element(labels.begin(), labels.end()) + 1;
|
classNumStates = *max_element(labels.begin(), labels.end()) + 1;
|
||||||
estimateParameters();
|
int maxThreadsRunning = static_cast<int>(std::thread::hardware_concurrency() * maxThreads);
|
||||||
}
|
if (maxThreadsRunning < 1) {
|
||||||
|
maxThreadsRunning = 1;
|
||||||
|
}
|
||||||
|
cout << "Using " << maxThreadsRunning << " threads" << " maxThreads: " << maxThreads << endl;
|
||||||
|
vector<thread> threads;
|
||||||
|
mutex mtx;
|
||||||
|
condition_variable cv;
|
||||||
|
int activeThreads = 0;
|
||||||
|
int nextNodeIndex = 0;
|
||||||
|
|
||||||
void Network::estimateParameters()
|
while (nextNodeIndex < nodes.size()) {
|
||||||
{
|
unique_lock<mutex> lock(mtx);
|
||||||
auto dimensions = vector<int64_t>();
|
cv.wait(lock, [&activeThreads, &maxThreadsRunning]() { return activeThreads < maxThreadsRunning; });
|
||||||
for (auto [name, node] : nodes) {
|
|
||||||
node->computeCPT(dataset, laplaceSmoothing);
|
if (nextNodeIndex >= nodes.size()) {
|
||||||
|
break; // No more work remaining
|
||||||
|
}
|
||||||
|
|
||||||
|
threads.emplace_back([this, &nextNodeIndex, &mtx, &cv, &activeThreads]() {
|
||||||
|
while (true) {
|
||||||
|
unique_lock<mutex> lock(mtx);
|
||||||
|
if (nextNodeIndex >= nodes.size()) {
|
||||||
|
break; // No more work remaining
|
||||||
|
}
|
||||||
|
auto& pair = *std::next(nodes.begin(), nextNodeIndex);
|
||||||
|
++nextNodeIndex;
|
||||||
|
lock.unlock();
|
||||||
|
|
||||||
|
pair.second->computeCPT(dataset, laplaceSmoothing);
|
||||||
|
|
||||||
|
lock.lock();
|
||||||
|
nodes[pair.first] = pair.second;
|
||||||
|
lock.unlock();
|
||||||
|
}
|
||||||
|
lock_guard<mutex> lock(mtx);
|
||||||
|
--activeThreads;
|
||||||
|
cv.notify_one();
|
||||||
|
});
|
||||||
|
|
||||||
|
++activeThreads;
|
||||||
|
}
|
||||||
|
for (auto& thread : threads) {
|
||||||
|
thread.join();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -193,7 +236,7 @@ namespace bayesnet {
|
|||||||
|
|
||||||
lock_guard<mutex> lock(mtx);
|
lock_guard<mutex> lock(mtx);
|
||||||
result[i] = factor;
|
result[i] = factor;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto& thread : threads) {
|
for (auto& thread : threads) {
|
||||||
@ -205,7 +248,6 @@ namespace bayesnet {
|
|||||||
for (double& value : result) {
|
for (double& value : result) {
|
||||||
value /= sum;
|
value /= sum;
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,7 @@ namespace bayesnet {
|
|||||||
map<string, Node*> nodes;
|
map<string, Node*> nodes;
|
||||||
map<string, vector<int>> dataset;
|
map<string, vector<int>> dataset;
|
||||||
Node* root;
|
Node* root;
|
||||||
|
float maxThreads;
|
||||||
int classNumStates;
|
int classNumStates;
|
||||||
vector<string> features;
|
vector<string> features;
|
||||||
string className;
|
string className;
|
||||||
@ -21,9 +22,11 @@ namespace bayesnet {
|
|||||||
double computeFactor(map<string, int>&);
|
double computeFactor(map<string, int>&);
|
||||||
public:
|
public:
|
||||||
Network();
|
Network();
|
||||||
Network(int);
|
Network(float, int);
|
||||||
|
Network(float);
|
||||||
Network(Network&);
|
Network(Network&);
|
||||||
~Network();
|
~Network();
|
||||||
|
float getmaxThreads();
|
||||||
void addNode(string, int);
|
void addNode(string, int);
|
||||||
void addEdge(const string, const string);
|
void addEdge(const string, const string);
|
||||||
map<string, Node*>& getNodes();
|
map<string, Node*>& getNodes();
|
||||||
@ -31,7 +34,6 @@ namespace bayesnet {
|
|||||||
int getClassNumStates();
|
int getClassNumStates();
|
||||||
string getClassName();
|
string getClassName();
|
||||||
void fit(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&);
|
void fit(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&);
|
||||||
void estimateParameters();
|
|
||||||
void setRoot(string);
|
void setRoot(string);
|
||||||
Node* getRoot();
|
Node* getRoot();
|
||||||
vector<int> predict(const vector<vector<int>>&);
|
vector<int> predict(const vector<vector<int>>&);
|
||||||
|
Loading…
Reference in New Issue
Block a user