Begin using validation as finish condition
This commit is contained in:
parent
5a7c8f1818
commit
d908f389f5
@ -2,6 +2,7 @@
|
||||
#include <set>
|
||||
#include "BayesMetrics.h"
|
||||
#include "Colors.h"
|
||||
#include "Folding.h"
|
||||
|
||||
namespace bayesnet {
|
||||
BoostAODE::BoostAODE() : Ensemble() {}
|
||||
@ -24,15 +25,35 @@ namespace bayesnet {
|
||||
ascending = hyperparameters["ascending"];
|
||||
}
|
||||
}
|
||||
void BoostAODE::validationInit()
|
||||
{
|
||||
auto y_ = dataset.index({ -1, "..." });
|
||||
auto fold = platform::StratifiedKFold(5, y_, 271);
|
||||
// save input dataset
|
||||
dataset_ = torch::clone(dataset);
|
||||
auto [train, test] = fold.getFold(0);
|
||||
auto train_t = torch::tensor(train);
|
||||
auto test_t = torch::tensor(test);
|
||||
// Get train and validation sets
|
||||
X_train = dataset.index({ "...", train_t });
|
||||
y_train = dataset.index({ -1, train_t });
|
||||
X_test = dataset.index({ "...", test_t });
|
||||
y_test = dataset.index({ -1, test_t });
|
||||
// Build dataset with train data
|
||||
dataset = X_train;
|
||||
buildDataset(y_train);
|
||||
m = X_train.size(1);
|
||||
auto n_classes = states.at(className).size();
|
||||
metrics = Metrics(dataset, features, className, n_classes);
|
||||
}
|
||||
void BoostAODE::trainModel(const torch::Tensor& weights)
|
||||
{
|
||||
models.clear();
|
||||
n_models = 0;
|
||||
if (maxModels == 0)
|
||||
maxModels = .1 * n > 10 ? .1 * n : n;
|
||||
validationInit();
|
||||
Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
|
||||
auto X_ = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." });
|
||||
auto y_ = dataset.index({ -1, "..." });
|
||||
bool exitCondition = false;
|
||||
unordered_set<int> featuresUsed;
|
||||
// Step 0: Set the finish condition
|
||||
@ -62,10 +83,10 @@ namespace bayesnet {
|
||||
model = std::make_unique<SPODE>(feature);
|
||||
n_models++;
|
||||
model->fit(dataset, features, className, states, weights_);
|
||||
auto ypred = model->predict(X_);
|
||||
auto ypred = model->predict(X_train);
|
||||
// Step 3.1: Compute the classifier amout of say
|
||||
auto mask_wrong = ypred != y_;
|
||||
auto mask_right = ypred == y_;
|
||||
auto mask_wrong = ypred != y_train;
|
||||
auto mask_right = ypred == y_train;
|
||||
auto masked_weights = weights_ * mask_wrong.to(weights_.dtype());
|
||||
double epsilon_t = masked_weights.sum().item<double>();
|
||||
double wt = (1 - epsilon_t) / epsilon_t;
|
||||
|
@ -13,9 +13,12 @@ namespace bayesnet {
|
||||
void buildModel(const torch::Tensor& weights) override;
|
||||
void trainModel(const torch::Tensor& weights) override;
|
||||
private:
|
||||
bool repeatSparent=false;
|
||||
int maxModels=0;
|
||||
bool ascending=false; //Process KBest features ascending or descending order
|
||||
torch::Tensor dataset_;
|
||||
torch::Tensor X_train, y_train, X_test, y_test;
|
||||
void validationInit();
|
||||
bool repeatSparent = false;
|
||||
int maxModels = 0;
|
||||
bool ascending = false; //Process KBest features ascending or descending order
|
||||
};
|
||||
}
|
||||
#endif
|
@ -75,7 +75,7 @@ namespace bayesnet {
|
||||
throw invalid_argument("dataset (X, y) must be of type Integer");
|
||||
}
|
||||
if (n != features.size()) {
|
||||
throw invalid_argument("X " + to_string(n) + " and features " + to_string(features.size()) + " must have the same number of features");
|
||||
throw invalid_argument("Classifier: X " + to_string(n) + " and features " + to_string(features.size()) + " must have the same number of features");
|
||||
}
|
||||
if (states.find(className) == states.end()) {
|
||||
throw invalid_argument("className not found in states");
|
||||
|
@ -10,7 +10,6 @@ using namespace torch;
|
||||
namespace bayesnet {
|
||||
class Classifier : public BaseClassifier {
|
||||
private:
|
||||
void buildDataset(torch::Tensor& y);
|
||||
Classifier& build(const vector<string>& features, const string& className, map<string, vector<int>>& states, const torch::Tensor& weights);
|
||||
protected:
|
||||
bool fitted;
|
||||
@ -26,6 +25,7 @@ namespace bayesnet {
|
||||
virtual void buildModel(const torch::Tensor& weights) = 0;
|
||||
void trainModel(const torch::Tensor& weights) override;
|
||||
void checkHyperparameters(const vector<string>& validKeys, nlohmann::json& hyperparameters);
|
||||
void buildDataset(torch::Tensor& y);
|
||||
public:
|
||||
Classifier(Network model);
|
||||
virtual ~Classifier() = default;
|
||||
|
Loading…
Reference in New Issue
Block a user