diff --git a/bayesnet/classifiers/KDBLd.cc b/bayesnet/classifiers/KDBLd.cc index 1b96bce..dbda394 100644 --- a/bayesnet/classifiers/KDBLd.cc +++ b/bayesnet/classifiers/KDBLd.cc @@ -8,7 +8,7 @@ #include namespace bayesnet { - KDBLd::KDBLd(int k) : KDB(k), Proposal(dataset, features, className) + KDBLd::KDBLd(int k) : KDB(k), Proposal(dataset, features, className, KDB::notes) { validHyperparameters = validHyperparameters_ld; validHyperparameters.push_back("k"); diff --git a/bayesnet/classifiers/Proposal.cc b/bayesnet/classifiers/Proposal.cc index b634b6e..b3c7639 100644 --- a/bayesnet/classifiers/Proposal.cc +++ b/bayesnet/classifiers/Proposal.cc @@ -16,7 +16,7 @@ #include "TANLd.h" namespace bayesnet { - Proposal::Proposal(torch::Tensor& dataset_, std::vector& features_, std::string& className_) : pDataset(dataset_), pFeatures(features_), pClassName(className_) + Proposal::Proposal(torch::Tensor& dataset_, std::vector& features_, std::string& className_, std::vector& notes_) : pDataset(dataset_), pFeatures(features_), pClassName(className_), notes(notes_) { } void Proposal::setHyperparameters(nlohmann::json& hyperparameters) @@ -215,6 +215,8 @@ namespace bayesnet { if (convergence_params.verbose) { std::cout << "Converged after " << (iteration + 1) << " iterations" << std::endl; } + notes.push_back("Converged after " + std::to_string(iteration + 1) + " of " + + std::to_string(convergence_params.maxIterations) + " iterations"); break; } diff --git a/bayesnet/classifiers/Proposal.h b/bayesnet/classifiers/Proposal.h index 9f23283..bb53776 100644 --- a/bayesnet/classifiers/Proposal.h +++ b/bayesnet/classifiers/Proposal.h @@ -18,7 +18,7 @@ namespace bayesnet { class Proposal { public: - Proposal(torch::Tensor& pDataset, std::vector& features_, std::string& className_); + Proposal(torch::Tensor& pDataset, std::vector& features_, std::string& className_, std::vector& notes); void setHyperparameters(nlohmann::json& hyperparameters_); protected: void checkInput(const torch::Tensor& X, const torch::Tensor& y); @@ -61,6 +61,7 @@ namespace bayesnet { }; private: std::vector factorize(const std::vector& labels_t); + std::vector& notes; // Notes during fit from BaseClassifier torch::Tensor& pDataset; // (n+1)xm tensor std::vector& pFeatures; std::string& pClassName; diff --git a/bayesnet/classifiers/SPODELd.cc b/bayesnet/classifiers/SPODELd.cc index 8cdbdec..0cffe63 100644 --- a/bayesnet/classifiers/SPODELd.cc +++ b/bayesnet/classifiers/SPODELd.cc @@ -7,7 +7,7 @@ #include "SPODELd.h" namespace bayesnet { - SPODELd::SPODELd(int root) : SPODE(root), Proposal(dataset, features, className) + SPODELd::SPODELd(int root) : SPODE(root), Proposal(dataset, features, className, SPODE::notes) { validHyperparameters = validHyperparameters_ld; // Inherits the valid hyperparameters from Proposal } diff --git a/bayesnet/classifiers/TANLd.cc b/bayesnet/classifiers/TANLd.cc index b415b0f..c9de329 100644 --- a/bayesnet/classifiers/TANLd.cc +++ b/bayesnet/classifiers/TANLd.cc @@ -8,7 +8,7 @@ #include namespace bayesnet { - TANLd::TANLd() : TAN(), Proposal(dataset, features, className) + TANLd::TANLd() : TAN(), Proposal(dataset, features, className, TAN::notes) { validHyperparameters = validHyperparameters_ld; // Inherits the valid hyperparameters from Proposal } diff --git a/bayesnet/ensembles/AODELd.cc b/bayesnet/ensembles/AODELd.cc index 3dc80bf..4f0f0cd 100644 --- a/bayesnet/ensembles/AODELd.cc +++ b/bayesnet/ensembles/AODELd.cc @@ -7,7 +7,7 @@ #include "AODELd.h" namespace bayesnet { - AODELd::AODELd(bool predict_voting) : Ensemble(predict_voting), Proposal(dataset, features, className) + AODELd::AODELd(bool predict_voting) : Ensemble(predict_voting), Proposal(dataset, features, className, Ensemble::notes) { validHyperparameters = validHyperparameters_ld; // Inherits the valid hyperparameters from Proposal } diff --git a/tests/TestBayesModels.cc b/tests/TestBayesModels.cc index 26cd773..4473c35 100644 --- a/tests/TestBayesModels.cc +++ b/tests/TestBayesModels.cc @@ -407,14 +407,15 @@ TEST_CASE("Check proposal checkInput", "[Models]") { class testProposal : public bayesnet::Proposal { public: - testProposal(torch::Tensor& dataset_, std::vector& features_, std::string& className_) - : Proposal(dataset_, features_, className_) + testProposal(torch::Tensor& dataset_, std::vector& features_, std::string& className_, std::vector& notes_) + : Proposal(dataset_, features_, className_, notes_) { } void test_X_y(const torch::Tensor& X, const torch::Tensor& y) { checkInput(X, y); } }; auto raw = RawDatasets("iris", true); - auto clf = testProposal(raw.dataset, raw.features, raw.className); + std::vector notes; + auto clf = testProposal(raw.dataset, raw.features, raw.className, notes); torch::Tensor X = torch::randint(0, 3, { 10, 4 }); torch::Tensor y = torch::rand({ 10 }); INFO("Check X is not float");