From b33da346553fd91fe97eb792cc54b2a9151d3a9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Thu, 8 Feb 2024 18:01:09 +0100 Subject: [PATCH] Add notes to Classifier & use them in BoostAODE --- src/BayesNet/BoostAODE.cc | 2 ++ src/BayesNet/Classifier.h | 1 + 2 files changed, 3 insertions(+) diff --git a/src/BayesNet/BoostAODE.cc b/src/BayesNet/BoostAODE.cc index 8178280..e2083f0 100644 --- a/src/BayesNet/BoostAODE.cc +++ b/src/BayesNet/BoostAODE.cc @@ -115,6 +115,7 @@ namespace bayesnet { significanceModels.push_back(1.0); n_models++; } + notes.push_back("Used features in initialization: " + std::to_string(featuresUsed.size()) + " of " + std::to_string(features.size()) + " with " + algorithm); delete featureSelector; return featuresUsed; } @@ -197,6 +198,7 @@ namespace bayesnet { exitCondition = n_models >= maxModels && repeatSparent || epsilon_t > 0.5 || count > tolerance; } if (featuresUsed.size() != features.size()) { + notes.push_back("Used features in train: " + std::to_string(featuresUsed.size()) + " of " + std::to_string(features.size())); status = WARNING; } } diff --git a/src/BayesNet/Classifier.h b/src/BayesNet/Classifier.h index 4bd2c57..9db25d9 100644 --- a/src/BayesNet/Classifier.h +++ b/src/BayesNet/Classifier.h @@ -19,6 +19,7 @@ namespace bayesnet { std::map> states; torch::Tensor dataset; // (n+1)xm tensor status_t status = NORMAL; + std::vector notes; // Used to store messages occurred during the fit process void checkFitParameters(); virtual void buildModel(const torch::Tensor& weights) = 0; void trainModel(const torch::Tensor& weights) override;