diff --git a/src/Network.cc b/src/Network.cc index 9f926b3..7239252 100644 --- a/src/Network.cc +++ b/src/Network.cc @@ -58,14 +58,12 @@ namespace bayesnet { // Temporarily add edge to check for cycles nodes[parent]->addChild(nodes[child]); nodes[child]->addParent(nodes[parent]); - // temporarily add edge unordered_set visited; unordered_set recStack; if (isCyclic(nodes[child]->getName(), visited, recStack)) // if adding this edge forms a cycle { // remove problematic edge nodes[parent]->removeChild(nodes[child]); - nodes[child]->removeParent(nodes[parent]); throw invalid_argument("Adding this edge forms a cycle in the graph."); } @@ -116,47 +114,37 @@ namespace bayesnet { node->setCPT(cpt); } } - pair Network::predict_sample(const vector& sample) - { - // Ensure the sample size is equal to the number of features - if (sample.size() != features.size()) { - throw std::invalid_argument("Sample size (" + to_string(sample.size()) + - ") does not match the number of features (" + to_string(features.size()) + ")"); - } + // pair Network::predict_sample(const vector& sample) + // { - // Map the feature values to their corresponding nodes - map featureValues; - for (int i = 0; i < features.size(); ++i) { - featureValues[features[i]] = sample[i]; - } - // For each possible class, calculate the posterior probability - Node* classNode = nodes[className]; - int numClassStates = classNode->getNumStates(); - std::vector classProbabilities(numClassStates, 0.0); - for (int classState = 0; classState < numClassStates; ++classState) { - // Start with the prior probability of the class - classProbabilities[classState] = classNode->getCPT()[classState].item(); + // // For each possible class, calculate the posterior probability + // Node* classNode = nodes[className]; + // int numClassStates = classNode->getNumStates(); + // vector classProbabilities(numClassStates, 0.0); + // for (int classState = 0; classState < numClassStates; ++classState) { + // // Start with the prior probability of the class + // classProbabilities[classState] = classNode->getCPT()[classState].item(); - // Multiply by the likelihood of each feature given the class - for (auto& pair : nodes) { - if (pair.first != className) { - Node* node = pair.second; - int featureValue = featureValues[pair.first]; + // // Multiply by the likelihood of each feature given the class + // for (auto& pair : nodes) { + // if (pair.first != className) { + // Node* node = pair.second; + // int featureValue = featureValues[pair.first]; - // We use the class as the parent state to index into the CPT - classProbabilities[classState] *= node->getCPT()[classState][featureValue].item(); - } - } - } + // // We use the class as the parent state to index into the CPT + // classProbabilities[classState] *= node->getCPT()[classState][featureValue].item(); + // } + // } + // } - // Find the class with the maximum posterior probability - auto maxElem = std::max_element(classProbabilities.begin(), classProbabilities.end()); - int predictedClass = std::distance(classProbabilities.begin(), maxElem); - double maxProbability = *maxElem; + // // Find the class with the maximum posterior probability + // auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end()); + // int predictedClass = distance(classProbabilities.begin(), maxElem); + // double maxProbability = *maxElem; - return std::make_pair(predictedClass, maxProbability); - } + // return make_pair(predictedClass, maxProbability); + // } vector Network::predict(const vector>& samples) { vector predictions; @@ -194,4 +182,37 @@ namespace bayesnet { } return (double)correct / y_pred.size(); } + pair Network::predict_sample(const vector& sample) + { + // Ensure the sample size is equal to the number of features + if (sample.size() != features.size()) { + throw invalid_argument("Sample size (" + to_string(sample.size()) + + ") does not match the number of features (" + to_string(features.size()) + ")"); + } + // Map the feature values to their corresponding nodes + map featureValues; + for (int i = 0; i < features.size(); ++i) { + featureValues[features[i]] = sample[i]; + } + + // For each possible class, calculate the posterior probability + Network network = *this; + vector classProbabilities = eliminateVariables(network, featureValues); + + // Normalize the probabilities to sum to 1 + double sum = accumulate(classProbabilities.begin(), classProbabilities.end(), 0.0); + for (double& prob : classProbabilities) { + prob /= sum; + } + // Find the class with the maximum posterior probability + auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end()); + int predictedClass = distance(classProbabilities.begin(), maxElem); + double maxProbability = *maxElem; + + return make_pair(predictedClass, maxProbability); + } + vector eliminateVariables(network, featureValues) + { + + } } diff --git a/src/Network.h b/src/Network.h index a155fed..1c95fa5 100644 --- a/src/Network.h +++ b/src/Network.h @@ -16,6 +16,7 @@ namespace bayesnet { int laplaceSmoothing; bool isCyclic(const std::string&, std::unordered_set&, std::unordered_set&); pair predict_sample(const vector&); + vector eliminateVariables(Network&, const map&); public: Network(); Network(int);