Begin using validation as finish condition

This commit is contained in:
Ricardo Montañana Gómez 2023-09-06 10:51:07 +02:00
parent 5a7c8f1818
commit d908f389f5
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
4 changed files with 34 additions and 10 deletions

View File

@ -2,6 +2,7 @@
#include <set> #include <set>
#include "BayesMetrics.h" #include "BayesMetrics.h"
#include "Colors.h" #include "Colors.h"
#include "Folding.h"
namespace bayesnet { namespace bayesnet {
BoostAODE::BoostAODE() : Ensemble() {} BoostAODE::BoostAODE() : Ensemble() {}
@ -24,15 +25,35 @@ namespace bayesnet {
ascending = hyperparameters["ascending"]; ascending = hyperparameters["ascending"];
} }
} }
void BoostAODE::validationInit()
{
auto y_ = dataset.index({ -1, "..." });
auto fold = platform::StratifiedKFold(5, y_, 271);
// save input dataset
dataset_ = torch::clone(dataset);
auto [train, test] = fold.getFold(0);
auto train_t = torch::tensor(train);
auto test_t = torch::tensor(test);
// Get train and validation sets
X_train = dataset.index({ "...", train_t });
y_train = dataset.index({ -1, train_t });
X_test = dataset.index({ "...", test_t });
y_test = dataset.index({ -1, test_t });
// Build dataset with train data
dataset = X_train;
buildDataset(y_train);
m = X_train.size(1);
auto n_classes = states.at(className).size();
metrics = Metrics(dataset, features, className, n_classes);
}
void BoostAODE::trainModel(const torch::Tensor& weights) void BoostAODE::trainModel(const torch::Tensor& weights)
{ {
models.clear(); models.clear();
n_models = 0; n_models = 0;
if (maxModels == 0) if (maxModels == 0)
maxModels = .1 * n > 10 ? .1 * n : n; maxModels = .1 * n > 10 ? .1 * n : n;
validationInit();
Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64); Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
auto X_ = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." });
auto y_ = dataset.index({ -1, "..." });
bool exitCondition = false; bool exitCondition = false;
unordered_set<int> featuresUsed; unordered_set<int> featuresUsed;
// Step 0: Set the finish condition // Step 0: Set the finish condition
@ -62,10 +83,10 @@ namespace bayesnet {
model = std::make_unique<SPODE>(feature); model = std::make_unique<SPODE>(feature);
n_models++; n_models++;
model->fit(dataset, features, className, states, weights_); model->fit(dataset, features, className, states, weights_);
auto ypred = model->predict(X_); auto ypred = model->predict(X_train);
// Step 3.1: Compute the classifier amout of say // Step 3.1: Compute the classifier amout of say
auto mask_wrong = ypred != y_; auto mask_wrong = ypred != y_train;
auto mask_right = ypred == y_; auto mask_right = ypred == y_train;
auto masked_weights = weights_ * mask_wrong.to(weights_.dtype()); auto masked_weights = weights_ * mask_wrong.to(weights_.dtype());
double epsilon_t = masked_weights.sum().item<double>(); double epsilon_t = masked_weights.sum().item<double>();
double wt = (1 - epsilon_t) / epsilon_t; double wt = (1 - epsilon_t) / epsilon_t;

View File

@ -13,9 +13,12 @@ namespace bayesnet {
void buildModel(const torch::Tensor& weights) override; void buildModel(const torch::Tensor& weights) override;
void trainModel(const torch::Tensor& weights) override; void trainModel(const torch::Tensor& weights) override;
private: private:
bool repeatSparent=false; torch::Tensor dataset_;
int maxModels=0; torch::Tensor X_train, y_train, X_test, y_test;
bool ascending=false; //Process KBest features ascending or descending order void validationInit();
bool repeatSparent = false;
int maxModels = 0;
bool ascending = false; //Process KBest features ascending or descending order
}; };
} }
#endif #endif

View File

@ -75,7 +75,7 @@ namespace bayesnet {
throw invalid_argument("dataset (X, y) must be of type Integer"); throw invalid_argument("dataset (X, y) must be of type Integer");
} }
if (n != features.size()) { if (n != features.size()) {
throw invalid_argument("X " + to_string(n) + " and features " + to_string(features.size()) + " must have the same number of features"); throw invalid_argument("Classifier: X " + to_string(n) + " and features " + to_string(features.size()) + " must have the same number of features");
} }
if (states.find(className) == states.end()) { if (states.find(className) == states.end()) {
throw invalid_argument("className not found in states"); throw invalid_argument("className not found in states");

View File

@ -10,7 +10,6 @@ using namespace torch;
namespace bayesnet { namespace bayesnet {
class Classifier : public BaseClassifier { class Classifier : public BaseClassifier {
private: private:
void buildDataset(torch::Tensor& y);
Classifier& build(const vector<string>& features, const string& className, map<string, vector<int>>& states, const torch::Tensor& weights); Classifier& build(const vector<string>& features, const string& className, map<string, vector<int>>& states, const torch::Tensor& weights);
protected: protected:
bool fitted; bool fitted;
@ -26,6 +25,7 @@ namespace bayesnet {
virtual void buildModel(const torch::Tensor& weights) = 0; virtual void buildModel(const torch::Tensor& weights) = 0;
void trainModel(const torch::Tensor& weights) override; void trainModel(const torch::Tensor& weights) override;
void checkHyperparameters(const vector<string>& validKeys, nlohmann::json& hyperparameters); void checkHyperparameters(const vector<string>& validKeys, nlohmann::json& hyperparameters);
void buildDataset(torch::Tensor& y);
public: public:
Classifier(Network model); Classifier(Network model);
virtual ~Classifier() = default; virtual ~Classifier() = default;