diff --git a/src/Network.cc b/src/Network.cc index 458eadb..391b8af 100644 --- a/src/Network.cc +++ b/src/Network.cc @@ -151,13 +151,17 @@ namespace bayesnet { for (int col = 0; col < samples.size(); ++col) { sample.push_back(samples[col][row]); } - predictions.push_back(predict_sample(sample).first); + vector classProbabilities = predict_sample(sample); + // Find the class with the maximum posterior probability + auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end()); + int predictedClass = distance(classProbabilities.begin(), maxElem); + predictions.push_back(predictedClass); } return predictions; } - vector> Network::predict_proba(const vector>& samples) + vector> Network::predict_proba(const vector>& samples) { - vector> predictions; + vector> predictions; vector sample; for (int row = 0; row < samples[0].size(); ++row) { sample.clear(); @@ -179,7 +183,7 @@ namespace bayesnet { } return (double)correct / y_pred.size(); } - pair Network::predict_sample(const vector& sample) + vector Network::predict_sample(const vector& sample) { // Ensure the sample size is equal to the number of features if (sample.size() != features.size()) { @@ -190,14 +194,8 @@ namespace bayesnet { for (int i = 0; i < sample.size(); ++i) { evidence[features[i]] = sample[i]; } - vector classProbabilities = exactInference(evidence); + return exactInference(evidence); - // Find the class with the maximum posterior probability - auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end()); - int predictedClass = distance(classProbabilities.begin(), maxElem); - double maxProbability = *maxElem; - - return make_pair(predictedClass, maxProbability); } double Network::computeFactor(map& completeEvidence) { diff --git a/src/Network.h b/src/Network.h index db4a1e3..78ce3ab 100644 --- a/src/Network.h +++ b/src/Network.h @@ -16,7 +16,7 @@ namespace bayesnet { string className; int laplaceSmoothing; bool isCyclic(const std::string&, std::unordered_set&, std::unordered_set&); - pair predict_sample(const vector&); + vector predict_sample(const vector&); vector exactInference(map&); double computeFactor(map&); public: @@ -34,7 +34,7 @@ namespace bayesnet { string getClassName(); void fit(const vector>&, const vector&, const vector&, const string&); vector predict(const vector>&); - vector> predict_proba(const vector>&); + vector> predict_proba(const vector>&); double score(const vector>&, const vector&); inline string version() { return "0.1.0"; } };