diff --git a/bayesnet/classifiers/XSPODE.cc b/bayesnet/classifiers/XSPODE.cc index ff7d332..ab0bf69 100644 --- a/bayesnet/classifiers/XSPODE.cc +++ b/bayesnet/classifiers/XSPODE.cc @@ -407,5 +407,24 @@ namespace bayesnet { } return result; } + float XSpode::score(torch::Tensor& X, torch::Tensor& y) + { + torch::Tensor y_pred = predict(X); + return (y_pred == y).sum().item() / y.size(0); + } + float XSpode::score(std::vector>& X, std::vector& y) + { + if (!fitted) { + throw std::logic_error(CLASSIFIER_NOT_FITTED); + } + auto y_pred = this->predict(X); + int correct = 0; + for (int i = 0; i < y_pred.size(); ++i) { + if (y_pred[i] == y[i]) { + correct++; + } + } + return (double)correct / y_pred.size(); + } } diff --git a/bayesnet/classifiers/XSPODE.h b/bayesnet/classifiers/XSPODE.h index 215e8ab..a8c24af 100644 --- a/bayesnet/classifiers/XSPODE.h +++ b/bayesnet/classifiers/XSPODE.h @@ -38,6 +38,8 @@ namespace bayesnet { torch::Tensor predict(torch::Tensor& X) override; std::vector predict(std::vector>& X) override; torch::Tensor predict_proba(torch::Tensor& X) override; + float score(torch::Tensor& X, torch::Tensor& y) override; + float score(std::vector>& X, std::vector& y) override; protected: void buildModel(const torch::Tensor& weights) override; void trainModel(const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing) override; diff --git a/bayesnet/ensembles/WA2DE.h b/bayesnet/ensembles/WA2DE.h index 7008025..246ce7b 100644 --- a/bayesnet/ensembles/WA2DE.h +++ b/bayesnet/ensembles/WA2DE.h @@ -30,7 +30,7 @@ namespace bayesnet { protected: // Model-building function void buildModel(const torch::Tensor& weights) override; - + void trainModel(const torch::Tensor& data, const Smoothing_t smoothing) override; private: int num_classes_; // Number of classes int num_attributes_; // Number of attributes @@ -46,7 +46,6 @@ namespace bayesnet { bool weighted_a2de_; // Whether to use weighted A2DE double smoothing_factor_; // Smoothing parameter (default: Laplace) torch::Tensor AODEConditionalProb(const torch::Tensor& data); - void trainModel(const torch::Tensor& data, const Smoothing_t smoothing); int toIntValue(int attributeIndex, float value) const; }; } diff --git a/sample/CMakeLists.txt b/sample/CMakeLists.txt index fbcfdcc..d3e79a4 100644 --- a/sample/CMakeLists.txt +++ b/sample/CMakeLists.txt @@ -22,4 +22,4 @@ include_directories( ) add_executable(bayesnet_sample sample.cc) -target_link_libraries(bayesnet_sample fimdlp "${TORCH_LIBRARIES}" "${BayesNet}") \ No newline at end of file +target_link_libraries(bayesnet_sample ${FImdlp} "${TORCH_LIBRARIES}" "${BayesNet}") \ No newline at end of file