#include #include #include "Network.h" namespace bayesnet { Network::Network() : laplaceSmoothing(1), features(vector()), className(""), classNumStates(0), maxThreads(0.8) {} Network::Network(float maxT) : laplaceSmoothing(1), features(vector()), className(""), classNumStates(0), maxThreads(maxT) {} Network::Network(float maxT, int smoothing) : laplaceSmoothing(smoothing), features(vector()), className(""), classNumStates(0), maxThreads(maxT) {} Network::Network(Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), maxThreads(other.getmaxThreads()) { for (auto& pair : other.nodes) { nodes[pair.first] = new Node(*pair.second); } } Network::~Network() { for (auto& pair : nodes) { delete pair.second; } } float Network::getmaxThreads() { return maxThreads; } torch::Tensor& Network::getSamples() { return samples; } void Network::addNode(string name, int numStates) { if (nodes.find(name) != nodes.end()) { // if node exists update its number of states nodes[name]->setNumStates(numStates); return; } nodes[name] = new Node(name, numStates); } vector Network::getFeatures() { return features; } int Network::getClassNumStates() { return classNumStates; } int Network::getStates() { int result = 0; for (auto node : nodes) { result += node.second->getNumStates(); } return result; } string Network::getClassName() { return className; } bool Network::isCyclic(const string& nodeId, unordered_set& visited, unordered_set& recStack) { if (visited.find(nodeId) == visited.end()) // if node hasn't been visited yet { visited.insert(nodeId); recStack.insert(nodeId); for (Node* child : nodes[nodeId]->getChildren()) { if (visited.find(child->getName()) == visited.end() && isCyclic(child->getName(), visited, recStack)) return true; else if (recStack.find(child->getName()) != recStack.end()) return true; } } recStack.erase(nodeId); // remove node from recursion stack before function ends return false; } void Network::addEdge(const string parent, const string child) { if (nodes.find(parent) == nodes.end()) { throw invalid_argument("Parent node " + parent + " does not exist"); } if (nodes.find(child) == nodes.end()) { throw invalid_argument("Child node " + child + " does not exist"); } // Temporarily add edge to check for cycles nodes[parent]->addChild(nodes[child]); nodes[child]->addParent(nodes[parent]); unordered_set visited; unordered_set recStack; if (isCyclic(nodes[child]->getName(), visited, recStack)) // if adding this edge forms a cycle { // remove problematic edge nodes[parent]->removeChild(nodes[child]); nodes[child]->removeParent(nodes[parent]); throw invalid_argument("Adding this edge forms a cycle in the graph."); } } map& Network::getNodes() { return nodes; } void Network::fit(const vector>& input_data, const vector& labels, const vector& featureNames, const string& className) { features = featureNames; this->className = className; dataset.clear(); // Build dataset & tensor of samples samples = torch::zeros({ static_cast(input_data[0].size()), static_cast(input_data.size() + 1) }, torch::kInt64); for (int i = 0; i < featureNames.size(); ++i) { dataset[featureNames[i]] = input_data[i]; samples.index_put_({ "...", i }, torch::tensor(input_data[i], torch::kInt64)); } dataset[className] = labels; samples.index_put_({ "...", -1 }, torch::tensor(labels, torch::kInt64)); classNumStates = *max_element(labels.begin(), labels.end()) + 1; int maxThreadsRunning = static_cast(std::thread::hardware_concurrency() * maxThreads); if (maxThreadsRunning < 1) { maxThreadsRunning = 1; } vector threads; mutex mtx; condition_variable cv; int activeThreads = 0; int nextNodeIndex = 0; while (nextNodeIndex < nodes.size()) { unique_lock lock(mtx); cv.wait(lock, [&activeThreads, &maxThreadsRunning]() { return activeThreads < maxThreadsRunning; }); if (nextNodeIndex >= nodes.size()) { break; // No more work remaining } threads.emplace_back([this, &nextNodeIndex, &mtx, &cv, &activeThreads]() { while (true) { unique_lock lock(mtx); if (nextNodeIndex >= nodes.size()) { break; // No more work remaining } auto& pair = *std::next(nodes.begin(), nextNodeIndex); ++nextNodeIndex; lock.unlock(); pair.second->computeCPT(dataset, laplaceSmoothing); lock.lock(); nodes[pair.first] = pair.second; lock.unlock(); } lock_guard lock(mtx); --activeThreads; cv.notify_one(); }); ++activeThreads; } for (auto& thread : threads) { thread.join(); } } vector Network::predict(const vector>& tsamples) { vector predictions; vector sample; for (int row = 0; row < tsamples[0].size(); ++row) { sample.clear(); for (int col = 0; col < tsamples.size(); ++col) { sample.push_back(tsamples[col][row]); } vector classProbabilities = predict_sample(sample); // Find the class with the maximum posterior probability auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end()); int predictedClass = distance(classProbabilities.begin(), maxElem); predictions.push_back(predictedClass); } return predictions; } vector> Network::predict_proba(const vector>& tsamples) { vector> predictions; vector sample; for (int row = 0; row < tsamples[0].size(); ++row) { sample.clear(); for (int col = 0; col < tsamples.size(); ++col) { sample.push_back(tsamples[col][row]); } predictions.push_back(predict_sample(sample)); } return predictions; } double Network::score(const vector>& tsamples, const vector& labels) { vector y_pred = predict(tsamples); int correct = 0; for (int i = 0; i < y_pred.size(); ++i) { if (y_pred[i] == labels[i]) { correct++; } } return (double)correct / y_pred.size(); } vector Network::predict_sample(const vector& sample) { // Ensure the sample size is equal to the number of features if (sample.size() != features.size()) { throw invalid_argument("Sample size (" + to_string(sample.size()) + ") does not match the number of features (" + to_string(features.size()) + ")"); } map evidence; for (int i = 0; i < sample.size(); ++i) { evidence[features[i]] = sample[i]; } return exactInference(evidence); } double Network::computeFactor(map& completeEvidence) { double result = 1.0; for (auto node : getNodes()) { result *= node.second->getFactorValue(completeEvidence); } return result; } vector Network::exactInference(map& evidence) { vector result(classNumStates, 0.0); vector threads; mutex mtx; for (int i = 0; i < classNumStates; ++i) { threads.emplace_back([this, &result, &evidence, i, &mtx]() { auto completeEvidence = map(evidence); completeEvidence[getClassName()] = i; double factor = computeFactor(completeEvidence); lock_guard lock(mtx); result[i] = factor; }); } for (auto& thread : threads) { thread.join(); } // Normalize result double sum = accumulate(result.begin(), result.end(), 0.0); for (double& value : result) { value /= sum; } return result; } vector Network::show() { vector result; // Draw the network for (auto node : nodes) { string line = node.first + " -> "; for (auto child : node.second->getChildren()) { line += child->getName() + ", "; } result.push_back(line); } return result; } }