Add Convergence hyperparameter
This commit is contained in:
parent
d908f389f5
commit
506369e46b
6
.vscode/launch.json
vendored
6
.vscode/launch.json
vendored
@ -25,9 +25,9 @@
|
||||
"program": "${workspaceFolder}/build/src/Platform/main",
|
||||
"args": [
|
||||
"-m",
|
||||
"AODE",
|
||||
"BoostAODE",
|
||||
"-p",
|
||||
"/home/rmontanana/Code/discretizbench/datasets",
|
||||
"/Users/rmontanana/Code/discretizbench/datasets",
|
||||
"--stratified",
|
||||
"-d",
|
||||
"mfeat-morphological",
|
||||
@ -35,7 +35,7 @@
|
||||
// "--hyperparameters",
|
||||
// "{\"repeatSparent\": true, \"maxModels\": 12}"
|
||||
],
|
||||
"cwd": "/home/rmontanana/Code/discretizbench",
|
||||
"cwd": "/Users/rmontanana/Code/discretizbench",
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
|
@ -3,6 +3,7 @@
|
||||
#include "BayesMetrics.h"
|
||||
#include "Colors.h"
|
||||
#include "Folding.h"
|
||||
#include <limits.h>
|
||||
|
||||
namespace bayesnet {
|
||||
BoostAODE::BoostAODE() : Ensemble() {}
|
||||
@ -13,7 +14,7 @@ namespace bayesnet {
|
||||
void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters)
|
||||
{
|
||||
// Check if hyperparameters are valid
|
||||
const vector<string> validKeys = { "repeatSparent", "maxModels", "ascending" };
|
||||
const vector<string> validKeys = { "repeatSparent", "maxModels", "ascending", "convergence" };
|
||||
checkHyperparameters(validKeys, hyperparameters);
|
||||
if (hyperparameters.contains("repeatSparent")) {
|
||||
repeatSparent = hyperparameters["repeatSparent"];
|
||||
@ -24,27 +25,30 @@ namespace bayesnet {
|
||||
if (hyperparameters.contains("ascending")) {
|
||||
ascending = hyperparameters["ascending"];
|
||||
}
|
||||
if (hyperparameters.contains("convergence")) {
|
||||
convergence = hyperparameters["convergence"];
|
||||
}
|
||||
}
|
||||
void BoostAODE::validationInit()
|
||||
{
|
||||
auto y_ = dataset.index({ -1, "..." });
|
||||
auto fold = platform::StratifiedKFold(5, y_, 271);
|
||||
// save input dataset
|
||||
dataset_ = torch::clone(dataset);
|
||||
// save input dataset
|
||||
auto [train, test] = fold.getFold(0);
|
||||
auto train_t = torch::tensor(train);
|
||||
auto test_t = torch::tensor(test);
|
||||
// Get train and validation sets
|
||||
X_train = dataset.index({ "...", train_t });
|
||||
X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), train_t });
|
||||
y_train = dataset.index({ -1, train_t });
|
||||
X_test = dataset.index({ "...", test_t });
|
||||
X_test = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), test_t });
|
||||
y_test = dataset.index({ -1, test_t });
|
||||
// Build dataset with train data
|
||||
dataset = X_train;
|
||||
buildDataset(y_train);
|
||||
m = X_train.size(1);
|
||||
auto n_classes = states.at(className).size();
|
||||
metrics = Metrics(dataset, features, className, n_classes);
|
||||
// Build dataset with train data
|
||||
buildDataset(y_train);
|
||||
}
|
||||
void BoostAODE::trainModel(const torch::Tensor& weights)
|
||||
{
|
||||
@ -56,9 +60,18 @@ namespace bayesnet {
|
||||
Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
|
||||
bool exitCondition = false;
|
||||
unordered_set<int> featuresUsed;
|
||||
// Variables to control the accuracy finish condition
|
||||
double priorAccuracy = 0.0;
|
||||
double delta = 1.0;
|
||||
double threshold = 1e-4;
|
||||
int tolerance = convergence ? 5 : INT_MAX; // number of times the accuracy can be lower than the threshold
|
||||
int count = 0; // number of times the accuracy is lower than the threshold
|
||||
fitted = true; // to enable predict
|
||||
// Step 0: Set the finish condition
|
||||
// if not repeatSparent a finish condition is run out of features
|
||||
// n_models == maxModels
|
||||
// epsiolon sub t > 0.5 => inverse the weights policy
|
||||
// validation error is not decreasing
|
||||
while (!exitCondition) {
|
||||
// Step 1: Build ranking with mutual information
|
||||
auto featureSelection = metrics.SelectKBestWeighted(weights_, ascending, n); // Get all the features sorted
|
||||
@ -81,7 +94,6 @@ namespace bayesnet {
|
||||
}
|
||||
featuresUsed.insert(feature);
|
||||
model = std::make_unique<SPODE>(feature);
|
||||
n_models++;
|
||||
model->fit(dataset, features, className, states, weights_);
|
||||
auto ypred = model->predict(X_train);
|
||||
// Step 3.1: Compute the classifier amout of say
|
||||
@ -102,12 +114,22 @@ namespace bayesnet {
|
||||
// Step 3.4: Store classifier and its accuracy to weigh its future vote
|
||||
models.push_back(std::move(model));
|
||||
significanceModels.push_back(alpha_t);
|
||||
exitCondition = n_models == maxModels && repeatSparent || epsilon_t > 0.5;
|
||||
n_models++;
|
||||
auto y_val_predict = predict(X_test);
|
||||
double accuracy = (y_val_predict == y_test).sum().item<double>() / (double)y_test.size(0);
|
||||
if (priorAccuracy == 0) {
|
||||
priorAccuracy = accuracy;
|
||||
} else {
|
||||
delta = accuracy - priorAccuracy;
|
||||
}
|
||||
if (delta < threshold) {
|
||||
count++;
|
||||
}
|
||||
exitCondition = n_models == maxModels && repeatSparent || epsilon_t > 0.5 || count > tolerance;
|
||||
}
|
||||
if (featuresUsed.size() != features.size()) {
|
||||
status = WARNING;
|
||||
}
|
||||
weights.copy_(weights_);
|
||||
}
|
||||
vector<string> BoostAODE::graph(const string& title) const
|
||||
{
|
||||
|
@ -19,6 +19,7 @@ namespace bayesnet {
|
||||
bool repeatSparent = false;
|
||||
int maxModels = 0;
|
||||
bool ascending = false; //Process KBest features ascending or descending order
|
||||
bool convergence = false; //if true, stop when the model does not improve
|
||||
};
|
||||
}
|
||||
#endif
|
@ -24,7 +24,7 @@ namespace bayesnet {
|
||||
// i.e. votes[0] contains how much value has the value 0 of class. That value is generated by the models predictions
|
||||
vector<double> votes(numClasses, 0.0);
|
||||
for (int j = 0; j < n_models; ++j) {
|
||||
votes[y_pred_[i][j]] += significanceModels[j];
|
||||
votes[y_pred_[i][j]] += significanceModels.at(j);
|
||||
}
|
||||
// argsort in descending order
|
||||
auto indices = argsort(votes);
|
||||
|
Loading…
Reference in New Issue
Block a user