Add Notes to Proposal convergence

This commit is contained in:
2025-07-08 18:50:09 +02:00
parent aa77745e55
commit e2a0c5f4a5
7 changed files with 13 additions and 9 deletions

View File

@@ -8,7 +8,7 @@
#include <memory> #include <memory>
namespace bayesnet { namespace bayesnet {
KDBLd::KDBLd(int k) : KDB(k), Proposal(dataset, features, className) KDBLd::KDBLd(int k) : KDB(k), Proposal(dataset, features, className, KDB::notes)
{ {
validHyperparameters = validHyperparameters_ld; validHyperparameters = validHyperparameters_ld;
validHyperparameters.push_back("k"); validHyperparameters.push_back("k");

View File

@@ -16,7 +16,7 @@
#include "TANLd.h" #include "TANLd.h"
namespace bayesnet { namespace bayesnet {
Proposal::Proposal(torch::Tensor& dataset_, std::vector<std::string>& features_, std::string& className_) : pDataset(dataset_), pFeatures(features_), pClassName(className_) Proposal::Proposal(torch::Tensor& dataset_, std::vector<std::string>& features_, std::string& className_, std::vector<std::string>& notes_) : pDataset(dataset_), pFeatures(features_), pClassName(className_), notes(notes_)
{ {
} }
void Proposal::setHyperparameters(nlohmann::json& hyperparameters) void Proposal::setHyperparameters(nlohmann::json& hyperparameters)
@@ -215,6 +215,8 @@ namespace bayesnet {
if (convergence_params.verbose) { if (convergence_params.verbose) {
std::cout << "Converged after " << (iteration + 1) << " iterations" << std::endl; std::cout << "Converged after " << (iteration + 1) << " iterations" << std::endl;
} }
notes.push_back("Converged after " + std::to_string(iteration + 1) + " of "
+ std::to_string(convergence_params.maxIterations) + " iterations");
break; break;
} }

View File

@@ -18,7 +18,7 @@
namespace bayesnet { namespace bayesnet {
class Proposal { class Proposal {
public: public:
Proposal(torch::Tensor& pDataset, std::vector<std::string>& features_, std::string& className_); Proposal(torch::Tensor& pDataset, std::vector<std::string>& features_, std::string& className_, std::vector<std::string>& notes);
void setHyperparameters(nlohmann::json& hyperparameters_); void setHyperparameters(nlohmann::json& hyperparameters_);
protected: protected:
void checkInput(const torch::Tensor& X, const torch::Tensor& y); void checkInput(const torch::Tensor& X, const torch::Tensor& y);
@@ -61,6 +61,7 @@ namespace bayesnet {
}; };
private: private:
std::vector<int> factorize(const std::vector<std::string>& labels_t); std::vector<int> factorize(const std::vector<std::string>& labels_t);
std::vector<std::string>& notes; // Notes during fit from BaseClassifier
torch::Tensor& pDataset; // (n+1)xm tensor torch::Tensor& pDataset; // (n+1)xm tensor
std::vector<std::string>& pFeatures; std::vector<std::string>& pFeatures;
std::string& pClassName; std::string& pClassName;

View File

@@ -7,7 +7,7 @@
#include "SPODELd.h" #include "SPODELd.h"
namespace bayesnet { namespace bayesnet {
SPODELd::SPODELd(int root) : SPODE(root), Proposal(dataset, features, className) SPODELd::SPODELd(int root) : SPODE(root), Proposal(dataset, features, className, SPODE::notes)
{ {
validHyperparameters = validHyperparameters_ld; // Inherits the valid hyperparameters from Proposal validHyperparameters = validHyperparameters_ld; // Inherits the valid hyperparameters from Proposal
} }

View File

@@ -8,7 +8,7 @@
#include <memory> #include <memory>
namespace bayesnet { namespace bayesnet {
TANLd::TANLd() : TAN(), Proposal(dataset, features, className) TANLd::TANLd() : TAN(), Proposal(dataset, features, className, TAN::notes)
{ {
validHyperparameters = validHyperparameters_ld; // Inherits the valid hyperparameters from Proposal validHyperparameters = validHyperparameters_ld; // Inherits the valid hyperparameters from Proposal
} }

View File

@@ -7,7 +7,7 @@
#include "AODELd.h" #include "AODELd.h"
namespace bayesnet { namespace bayesnet {
AODELd::AODELd(bool predict_voting) : Ensemble(predict_voting), Proposal(dataset, features, className) AODELd::AODELd(bool predict_voting) : Ensemble(predict_voting), Proposal(dataset, features, className, Ensemble::notes)
{ {
validHyperparameters = validHyperparameters_ld; // Inherits the valid hyperparameters from Proposal validHyperparameters = validHyperparameters_ld; // Inherits the valid hyperparameters from Proposal
} }

View File

@@ -407,14 +407,15 @@ TEST_CASE("Check proposal checkInput", "[Models]")
{ {
class testProposal : public bayesnet::Proposal { class testProposal : public bayesnet::Proposal {
public: public:
testProposal(torch::Tensor& dataset_, std::vector<std::string>& features_, std::string& className_) testProposal(torch::Tensor& dataset_, std::vector<std::string>& features_, std::string& className_, std::vector<std::string>& notes_)
: Proposal(dataset_, features_, className_) : Proposal(dataset_, features_, className_, notes_)
{ {
} }
void test_X_y(const torch::Tensor& X, const torch::Tensor& y) { checkInput(X, y); } void test_X_y(const torch::Tensor& X, const torch::Tensor& y) { checkInput(X, y); }
}; };
auto raw = RawDatasets("iris", true); auto raw = RawDatasets("iris", true);
auto clf = testProposal(raw.dataset, raw.features, raw.className); std::vector<std::string> notes;
auto clf = testProposal(raw.dataset, raw.features, raw.className, notes);
torch::Tensor X = torch::randint(0, 3, { 10, 4 }); torch::Tensor X = torch::randint(0, 3, { 10, 4 });
torch::Tensor y = torch::rand({ 10 }); torch::Tensor y = torch::rand({ 10 });
INFO("Check X is not float"); INFO("Check X is not float");