Begin using validation as finish condition

Ricardo Montañana Gómez 2023-09-06 10:51:07 +02:00
parent 5a7c8f1818
commit d908f389f5
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
4 changed files with 34 additions and 10 deletions

View File

@@ -2,6 +2,7 @@
#include <set>
#include "BayesMetrics.h"
#include "Colors.h"
#include "Folding.h"
namespace bayesnet {
BoostAODE::BoostAODE() : Ensemble() {}
@@ -24,15 +25,35 @@ namespace bayesnet {
ascending = hyperparameters["ascending"];
}
}
void BoostAODE::validationInit()
{
auto y_ = dataset.index({ -1, "..." });
auto fold = platform::StratifiedKFold(5, y_, 271);
// save input dataset
dataset_ = torch::clone(dataset);
auto [train, test] = fold.getFold(0);
auto train_t = torch::tensor(train);
auto test_t = torch::tensor(test);
// Get train and validation sets
X_train = dataset.index({ "...", train_t });
y_train = dataset.index({ -1, train_t });
X_test = dataset.index({ "...", test_t });
y_test = dataset.index({ -1, test_t });
// Build dataset with train data
dataset = X_train;
buildDataset(y_train);
m = X_train.size(1);
auto n_classes = states.at(className).size();
metrics = Metrics(dataset, features, className, n_classes);
}
void BoostAODE::trainModel(const torch::Tensor& weights)
{
models.clear();
n_models = 0;
if (maxModels == 0)
maxModels = .1 * n > 10 ? .1 * n : n;
validationInit();
Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
auto X_ = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." });
auto y_ = dataset.index({ -1, "..." });
bool exitCondition = false;
unordered_set<int> featuresUsed;
// Step 0: Set the finish condition
@@ -62,10 +83,10 @@ namespace bayesnet {
model = std::make_unique<SPODE>(feature);
n_models++;
model->fit(dataset, features, className, states, weights_);
auto ypred = model->predict(X_);
auto ypred = model->predict(X_train);
// Step 3.1: Compute the classifier amount of say
auto mask_wrong = ypred != y_;
auto mask_right = ypred == y_;
auto mask_wrong = ypred != y_train;
auto mask_right = ypred == y_train;
auto masked_weights = weights_ * mask_wrong.to(weights_.dtype());
double epsilon_t = masked_weights.sum().item<double>();
double wt = (1 - epsilon_t) / epsilon_t;
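
The last lines of this hunk compute the classifier's amount of say the AdaBoost way: epsilon_t is the weight mass sitting on the misclassified samples, and wt = (1 - epsilon_t) / epsilon_t is the usual odds ratio derived from it. The commit title points at the other half of the picture, using the held-out split as the finish condition, which this diff only prepares (X_test / y_test) and does not yet check. Below is a minimal, self-contained sketch of the textbook form of both steps in plain C++; every name in it (patience, val_scores, ...) is illustrative and none of it is BayesNet code.

// Minimal sketch, plain C++ only: AdaBoost-style weight update plus a
// validation-based stop. All names here are illustrative, not library API.
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    // One boosting round on toy data: labels, predictions, uniform weights.
    std::vector<int> y_train = { 0, 1, 1, 0, 1 };
    std::vector<int> y_pred = { 0, 1, 0, 0, 1 };            // one mistake at index 2
    std::vector<double> weights(y_train.size(), 1.0 / y_train.size());

    // epsilon_t: weighted error; wt = (1 - eps) / eps, as in the hunk above.
    double epsilon_t = 0.0;
    for (std::size_t i = 0; i < y_train.size(); ++i)
        if (y_pred[i] != y_train[i])
            epsilon_t += weights[i];
    double wt = (1.0 - epsilon_t) / epsilon_t;

    // Up-weight the misclassified samples and renormalize the distribution.
    for (std::size_t i = 0; i < weights.size(); ++i)
        if (y_pred[i] != y_train[i])
            weights[i] *= wt;
    double total = 0.0;
    for (double w : weights) total += w;
    for (double& w : weights) w /= total;

    // Validation as finish condition: stop adding models once the score on the
    // held-out split has not improved for `patience` rounds (toy numbers).
    std::vector<double> val_scores = { 0.70, 0.74, 0.74, 0.73, 0.72, 0.71 };
    double best = 0.0;
    int worse = 0, patience = 3, models_kept = 0;
    for (double score : val_scores) {
        ++models_kept;
        if (score > best) { best = score; worse = 0; }
        else if (++worse >= patience) break;
    }
    std::cout << "epsilon_t=" << epsilon_t << " wt=" << wt
              << " models kept=" << models_kept << " best score=" << best << "\n";
}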

View File

@@ -13,9 +13,12 @@ namespace bayesnet {
void buildModel(const torch::Tensor& weights) override;
void trainModel(const torch::Tensor& weights) override;
private:
bool repeatSparent=false;
int maxModels=0;
bool ascending=false; // Process KBest features in ascending or descending order
torch::Tensor dataset_;
torch::Tensor X_train, y_train, X_test, y_test;
void validationInit();
bool repeatSparent = false;
int maxModels = 0;
bool ascending = false; // Process KBest features in ascending or descending order
};
}
#endif
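
These new members are the plumbing for the validation split: dataset_ keeps a clone of the complete dataset, X_train / y_train and X_test / y_test hold the two sides of fold 0, and validationInit() fills them from platform::StratifiedKFold(5, y_, 271). As a rough illustration of what a stratified hold-out yields, each class contributes the same proportion of samples to both sides; this sketch is not the Folding implementation (no shuffling, no seed handling), just the idea.

// Conceptual sketch of a stratified hold-out: split every class 80/20 so that
// class proportions are preserved in both partitions. Toy code, not Folding.
#include <cstddef>
#include <iostream>
#include <map>
#include <vector>

int main()
{
    std::vector<int> y = { 0, 0, 0, 0, 0, 1, 1, 1, 1, 1 };    // toy labels
    std::map<int, std::vector<int>> by_class;                  // class -> sample indices
    for (int i = 0; i < static_cast<int>(y.size()); ++i)
        by_class[y[i]].push_back(i);

    std::vector<int> train, test;
    for (const auto& kv : by_class) {
        const std::vector<int>& idx = kv.second;
        std::size_t cut = idx.size() / 5;                      // 1/5 of each class held out
        test.insert(test.end(), idx.begin(), idx.begin() + cut);
        train.insert(train.end(), idx.begin() + cut, idx.end());
    }
    std::cout << "train samples=" << train.size() << " validation samples=" << test.size() << "\n";
}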

View File

@@ -75,7 +75,7 @@ namespace bayesnet {
throw invalid_argument("dataset (X, y) must be of type Integer");
}
if (n != features.size()) {
throw invalid_argument("X " + to_string(n) + " and features " + to_string(features.size()) + " must have the same number of features");
throw invalid_argument("Classifier: X " + to_string(n) + " and features " + to_string(features.size()) + " must have the same number of features");
}
if (states.find(className) == states.end()) {
throw invalid_argument("className not found in states");

View File

@@ -10,7 +10,6 @@ using namespace torch;
namespace bayesnet {
class Classifier : public BaseClassifier {
private:
void buildDataset(torch::Tensor& y);
Classifier& build(const vector<string>& features, const string& className, map<string, vector<int>>& states, const torch::Tensor& weights);
protected:
bool fitted;
@@ -26,6 +25,7 @@ namespace bayesnet {
virtual void buildModel(const torch::Tensor& weights) = 0;
void trainModel(const torch::Tensor& weights) override;
void checkHyperparameters(const vector<string>& validKeys, nlohmann::json& hyperparameters);
void buildDataset(torch::Tensor& y);
public:
Classifier(Network model);
virtual ~Classifier() = default;
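
The last two hunks move buildDataset() from the private to the protected section of Classifier. The reason is visible in the first file: BoostAODE::validationInit() now calls buildDataset(y_train) to rebuild the internal dataset from the reduced training split, and a private member would not be reachable from the derived class. A toy illustration of that access rule follows; the bodies are placeholders, not the real declarations.

// Toy illustration of the private -> protected move. Skeleton classes only,
// not the real Classifier / BoostAODE declarations.
#include <iostream>

class Classifier {
protected:                        // was private: derived classes could not call it
    void buildDataset(int y) { std::cout << "rebuilding dataset, y=" << y << "\n"; }
};

class BoostAODE : public Classifier {
public:
    void validationInit() { buildDataset(42); }   // legal only because it is protected
};

int main()
{
    BoostAODE model;
    model.validationInit();
}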