Add Notes to Proposal convergence
This commit is contained in:
@@ -8,7 +8,7 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
KDBLd::KDBLd(int k) : KDB(k), Proposal(dataset, features, className)
|
KDBLd::KDBLd(int k) : KDB(k), Proposal(dataset, features, className, KDB::notes)
|
||||||
{
|
{
|
||||||
validHyperparameters = validHyperparameters_ld;
|
validHyperparameters = validHyperparameters_ld;
|
||||||
validHyperparameters.push_back("k");
|
validHyperparameters.push_back("k");
|
||||||
|
@@ -16,7 +16,7 @@
|
|||||||
#include "TANLd.h"
|
#include "TANLd.h"
|
||||||
|
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
Proposal::Proposal(torch::Tensor& dataset_, std::vector<std::string>& features_, std::string& className_) : pDataset(dataset_), pFeatures(features_), pClassName(className_)
|
Proposal::Proposal(torch::Tensor& dataset_, std::vector<std::string>& features_, std::string& className_, std::vector<std::string>& notes_) : pDataset(dataset_), pFeatures(features_), pClassName(className_), notes(notes_)
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
void Proposal::setHyperparameters(nlohmann::json& hyperparameters)
|
void Proposal::setHyperparameters(nlohmann::json& hyperparameters)
|
||||||
@@ -215,6 +215,8 @@ namespace bayesnet {
|
|||||||
if (convergence_params.verbose) {
|
if (convergence_params.verbose) {
|
||||||
std::cout << "Converged after " << (iteration + 1) << " iterations" << std::endl;
|
std::cout << "Converged after " << (iteration + 1) << " iterations" << std::endl;
|
||||||
}
|
}
|
||||||
|
notes.push_back("Converged after " + std::to_string(iteration + 1) + " of "
|
||||||
|
+ std::to_string(convergence_params.maxIterations) + " iterations");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -18,7 +18,7 @@
|
|||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
class Proposal {
|
class Proposal {
|
||||||
public:
|
public:
|
||||||
Proposal(torch::Tensor& pDataset, std::vector<std::string>& features_, std::string& className_);
|
Proposal(torch::Tensor& pDataset, std::vector<std::string>& features_, std::string& className_, std::vector<std::string>& notes);
|
||||||
void setHyperparameters(nlohmann::json& hyperparameters_);
|
void setHyperparameters(nlohmann::json& hyperparameters_);
|
||||||
protected:
|
protected:
|
||||||
void checkInput(const torch::Tensor& X, const torch::Tensor& y);
|
void checkInput(const torch::Tensor& X, const torch::Tensor& y);
|
||||||
@@ -61,6 +61,7 @@ namespace bayesnet {
|
|||||||
};
|
};
|
||||||
private:
|
private:
|
||||||
std::vector<int> factorize(const std::vector<std::string>& labels_t);
|
std::vector<int> factorize(const std::vector<std::string>& labels_t);
|
||||||
|
std::vector<std::string>& notes; // Notes during fit from BaseClassifier
|
||||||
torch::Tensor& pDataset; // (n+1)xm tensor
|
torch::Tensor& pDataset; // (n+1)xm tensor
|
||||||
std::vector<std::string>& pFeatures;
|
std::vector<std::string>& pFeatures;
|
||||||
std::string& pClassName;
|
std::string& pClassName;
|
||||||
|
@@ -7,7 +7,7 @@
|
|||||||
#include "SPODELd.h"
|
#include "SPODELd.h"
|
||||||
|
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
SPODELd::SPODELd(int root) : SPODE(root), Proposal(dataset, features, className)
|
SPODELd::SPODELd(int root) : SPODE(root), Proposal(dataset, features, className, SPODE::notes)
|
||||||
{
|
{
|
||||||
validHyperparameters = validHyperparameters_ld; // Inherits the valid hyperparameters from Proposal
|
validHyperparameters = validHyperparameters_ld; // Inherits the valid hyperparameters from Proposal
|
||||||
}
|
}
|
||||||
|
@@ -8,7 +8,7 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
TANLd::TANLd() : TAN(), Proposal(dataset, features, className)
|
TANLd::TANLd() : TAN(), Proposal(dataset, features, className, TAN::notes)
|
||||||
{
|
{
|
||||||
validHyperparameters = validHyperparameters_ld; // Inherits the valid hyperparameters from Proposal
|
validHyperparameters = validHyperparameters_ld; // Inherits the valid hyperparameters from Proposal
|
||||||
}
|
}
|
||||||
|
@@ -7,7 +7,7 @@
|
|||||||
#include "AODELd.h"
|
#include "AODELd.h"
|
||||||
|
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
AODELd::AODELd(bool predict_voting) : Ensemble(predict_voting), Proposal(dataset, features, className)
|
AODELd::AODELd(bool predict_voting) : Ensemble(predict_voting), Proposal(dataset, features, className, Ensemble::notes)
|
||||||
{
|
{
|
||||||
validHyperparameters = validHyperparameters_ld; // Inherits the valid hyperparameters from Proposal
|
validHyperparameters = validHyperparameters_ld; // Inherits the valid hyperparameters from Proposal
|
||||||
}
|
}
|
||||||
|
@@ -407,14 +407,15 @@ TEST_CASE("Check proposal checkInput", "[Models]")
|
|||||||
{
|
{
|
||||||
class testProposal : public bayesnet::Proposal {
|
class testProposal : public bayesnet::Proposal {
|
||||||
public:
|
public:
|
||||||
testProposal(torch::Tensor& dataset_, std::vector<std::string>& features_, std::string& className_)
|
testProposal(torch::Tensor& dataset_, std::vector<std::string>& features_, std::string& className_, std::vector<std::string>& notes_)
|
||||||
: Proposal(dataset_, features_, className_)
|
: Proposal(dataset_, features_, className_, notes_)
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
void test_X_y(const torch::Tensor& X, const torch::Tensor& y) { checkInput(X, y); }
|
void test_X_y(const torch::Tensor& X, const torch::Tensor& y) { checkInput(X, y); }
|
||||||
};
|
};
|
||||||
auto raw = RawDatasets("iris", true);
|
auto raw = RawDatasets("iris", true);
|
||||||
auto clf = testProposal(raw.dataset, raw.features, raw.className);
|
std::vector<std::string> notes;
|
||||||
|
auto clf = testProposal(raw.dataset, raw.features, raw.className, notes);
|
||||||
torch::Tensor X = torch::randint(0, 3, { 10, 4 });
|
torch::Tensor X = torch::randint(0, 3, { 10, 4 });
|
||||||
torch::Tensor y = torch::rand({ 10 });
|
torch::Tensor y = torch::rand({ 10 });
|
||||||
INFO("Check X is not float");
|
INFO("Check X is not float");
|
||||||
|
Reference in New Issue
Block a user