Begin using validation as finish condition
This commit is contained in:
parent
5a7c8f1818
commit
d908f389f5
@ -2,6 +2,7 @@
|
|||||||
#include <set>
|
#include <set>
|
||||||
#include "BayesMetrics.h"
|
#include "BayesMetrics.h"
|
||||||
#include "Colors.h"
|
#include "Colors.h"
|
||||||
|
#include "Folding.h"
|
||||||
|
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
BoostAODE::BoostAODE() : Ensemble() {}
|
BoostAODE::BoostAODE() : Ensemble() {}
|
||||||
@ -24,15 +25,35 @@ namespace bayesnet {
|
|||||||
ascending = hyperparameters["ascending"];
|
ascending = hyperparameters["ascending"];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
void BoostAODE::validationInit()
|
||||||
|
{
|
||||||
|
auto y_ = dataset.index({ -1, "..." });
|
||||||
|
auto fold = platform::StratifiedKFold(5, y_, 271);
|
||||||
|
// save input dataset
|
||||||
|
dataset_ = torch::clone(dataset);
|
||||||
|
auto [train, test] = fold.getFold(0);
|
||||||
|
auto train_t = torch::tensor(train);
|
||||||
|
auto test_t = torch::tensor(test);
|
||||||
|
// Get train and validation sets
|
||||||
|
X_train = dataset.index({ "...", train_t });
|
||||||
|
y_train = dataset.index({ -1, train_t });
|
||||||
|
X_test = dataset.index({ "...", test_t });
|
||||||
|
y_test = dataset.index({ -1, test_t });
|
||||||
|
// Build dataset with train data
|
||||||
|
dataset = X_train;
|
||||||
|
buildDataset(y_train);
|
||||||
|
m = X_train.size(1);
|
||||||
|
auto n_classes = states.at(className).size();
|
||||||
|
metrics = Metrics(dataset, features, className, n_classes);
|
||||||
|
}
|
||||||
void BoostAODE::trainModel(const torch::Tensor& weights)
|
void BoostAODE::trainModel(const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
models.clear();
|
models.clear();
|
||||||
n_models = 0;
|
n_models = 0;
|
||||||
if (maxModels == 0)
|
if (maxModels == 0)
|
||||||
maxModels = .1 * n > 10 ? .1 * n : n;
|
maxModels = .1 * n > 10 ? .1 * n : n;
|
||||||
|
validationInit();
|
||||||
Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
|
Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
|
||||||
auto X_ = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." });
|
|
||||||
auto y_ = dataset.index({ -1, "..." });
|
|
||||||
bool exitCondition = false;
|
bool exitCondition = false;
|
||||||
unordered_set<int> featuresUsed;
|
unordered_set<int> featuresUsed;
|
||||||
// Step 0: Set the finish condition
|
// Step 0: Set the finish condition
|
||||||
@ -62,10 +83,10 @@ namespace bayesnet {
|
|||||||
model = std::make_unique<SPODE>(feature);
|
model = std::make_unique<SPODE>(feature);
|
||||||
n_models++;
|
n_models++;
|
||||||
model->fit(dataset, features, className, states, weights_);
|
model->fit(dataset, features, className, states, weights_);
|
||||||
auto ypred = model->predict(X_);
|
auto ypred = model->predict(X_train);
|
||||||
// Step 3.1: Compute the classifier amount of say
|
// Step 3.1: Compute the classifier amount of say
|
||||||
auto mask_wrong = ypred != y_;
|
auto mask_wrong = ypred != y_train;
|
||||||
auto mask_right = ypred == y_;
|
auto mask_right = ypred == y_train;
|
||||||
auto masked_weights = weights_ * mask_wrong.to(weights_.dtype());
|
auto masked_weights = weights_ * mask_wrong.to(weights_.dtype());
|
||||||
double epsilon_t = masked_weights.sum().item<double>();
|
double epsilon_t = masked_weights.sum().item<double>();
|
||||||
double wt = (1 - epsilon_t) / epsilon_t;
|
double wt = (1 - epsilon_t) / epsilon_t;
|
||||||
|
@ -13,6 +13,9 @@ namespace bayesnet {
|
|||||||
void buildModel(const torch::Tensor& weights) override;
|
void buildModel(const torch::Tensor& weights) override;
|
||||||
void trainModel(const torch::Tensor& weights) override;
|
void trainModel(const torch::Tensor& weights) override;
|
||||||
private:
|
private:
|
||||||
|
torch::Tensor dataset_;
|
||||||
|
torch::Tensor X_train, y_train, X_test, y_test;
|
||||||
|
void validationInit();
|
||||||
bool repeatSparent = false;
|
bool repeatSparent = false;
|
||||||
int maxModels = 0;
|
int maxModels = 0;
|
||||||
bool ascending = false; //Process KBest features ascending or descending order
|
bool ascending = false; //Process KBest features ascending or descending order
|
||||||
|
@ -75,7 +75,7 @@ namespace bayesnet {
|
|||||||
throw invalid_argument("dataset (X, y) must be of type Integer");
|
throw invalid_argument("dataset (X, y) must be of type Integer");
|
||||||
}
|
}
|
||||||
if (n != features.size()) {
|
if (n != features.size()) {
|
||||||
throw invalid_argument("X " + to_string(n) + " and features " + to_string(features.size()) + " must have the same number of features");
|
throw invalid_argument("Classifier: X " + to_string(n) + " and features " + to_string(features.size()) + " must have the same number of features");
|
||||||
}
|
}
|
||||||
if (states.find(className) == states.end()) {
|
if (states.find(className) == states.end()) {
|
||||||
throw invalid_argument("className not found in states");
|
throw invalid_argument("className not found in states");
|
||||||
|
@ -10,7 +10,6 @@ using namespace torch;
|
|||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
class Classifier : public BaseClassifier {
|
class Classifier : public BaseClassifier {
|
||||||
private:
|
private:
|
||||||
void buildDataset(torch::Tensor& y);
|
|
||||||
Classifier& build(const vector<string>& features, const string& className, map<string, vector<int>>& states, const torch::Tensor& weights);
|
Classifier& build(const vector<string>& features, const string& className, map<string, vector<int>>& states, const torch::Tensor& weights);
|
||||||
protected:
|
protected:
|
||||||
bool fitted;
|
bool fitted;
|
||||||
@ -26,6 +25,7 @@ namespace bayesnet {
|
|||||||
virtual void buildModel(const torch::Tensor& weights) = 0;
|
virtual void buildModel(const torch::Tensor& weights) = 0;
|
||||||
void trainModel(const torch::Tensor& weights) override;
|
void trainModel(const torch::Tensor& weights) override;
|
||||||
void checkHyperparameters(const vector<string>& validKeys, nlohmann::json& hyperparameters);
|
void checkHyperparameters(const vector<string>& validKeys, nlohmann::json& hyperparameters);
|
||||||
|
void buildDataset(torch::Tensor& y);
|
||||||
public:
|
public:
|
||||||
Classifier(Network model);
|
Classifier(Network model);
|
||||||
virtual ~Classifier() = default;
|
virtual ~Classifier() = default;
|
||||||
|
Loading…
Reference in New Issue
Block a user