Add Convergence hyperparameter
This commit is contained in:
parent
d908f389f5
commit
506369e46b
6
.vscode/launch.json
vendored
6
.vscode/launch.json
vendored
@ -25,9 +25,9 @@
|
|||||||
"program": "${workspaceFolder}/build/src/Platform/main",
|
"program": "${workspaceFolder}/build/src/Platform/main",
|
||||||
"args": [
|
"args": [
|
||||||
"-m",
|
"-m",
|
||||||
"AODE",
|
"BoostAODE",
|
||||||
"-p",
|
"-p",
|
||||||
"/home/rmontanana/Code/discretizbench/datasets",
|
"/Users/rmontanana/Code/discretizbench/datasets",
|
||||||
"--stratified",
|
"--stratified",
|
||||||
"-d",
|
"-d",
|
||||||
"mfeat-morphological",
|
"mfeat-morphological",
|
||||||
@ -35,7 +35,7 @@
|
|||||||
// "--hyperparameters",
|
// "--hyperparameters",
|
||||||
// "{\"repeatSparent\": true, \"maxModels\": 12}"
|
// "{\"repeatSparent\": true, \"maxModels\": 12}"
|
||||||
],
|
],
|
||||||
"cwd": "/home/rmontanana/Code/discretizbench",
|
"cwd": "/Users/rmontanana/Code/discretizbench",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "lldb",
|
"type": "lldb",
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
#include "BayesMetrics.h"
|
#include "BayesMetrics.h"
|
||||||
#include "Colors.h"
|
#include "Colors.h"
|
||||||
#include "Folding.h"
|
#include "Folding.h"
|
||||||
|
#include <limits.h>
|
||||||
|
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
BoostAODE::BoostAODE() : Ensemble() {}
|
BoostAODE::BoostAODE() : Ensemble() {}
|
||||||
@ -13,7 +14,7 @@ namespace bayesnet {
|
|||||||
void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters)
|
void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters)
|
||||||
{
|
{
|
||||||
// Check if hyperparameters are valid
|
// Check if hyperparameters are valid
|
||||||
const vector<string> validKeys = { "repeatSparent", "maxModels", "ascending" };
|
const vector<string> validKeys = { "repeatSparent", "maxModels", "ascending", "convergence" };
|
||||||
checkHyperparameters(validKeys, hyperparameters);
|
checkHyperparameters(validKeys, hyperparameters);
|
||||||
if (hyperparameters.contains("repeatSparent")) {
|
if (hyperparameters.contains("repeatSparent")) {
|
||||||
repeatSparent = hyperparameters["repeatSparent"];
|
repeatSparent = hyperparameters["repeatSparent"];
|
||||||
@ -24,27 +25,30 @@ namespace bayesnet {
|
|||||||
if (hyperparameters.contains("ascending")) {
|
if (hyperparameters.contains("ascending")) {
|
||||||
ascending = hyperparameters["ascending"];
|
ascending = hyperparameters["ascending"];
|
||||||
}
|
}
|
||||||
|
if (hyperparameters.contains("convergence")) {
|
||||||
|
convergence = hyperparameters["convergence"];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
void BoostAODE::validationInit()
|
void BoostAODE::validationInit()
|
||||||
{
|
{
|
||||||
auto y_ = dataset.index({ -1, "..." });
|
auto y_ = dataset.index({ -1, "..." });
|
||||||
auto fold = platform::StratifiedKFold(5, y_, 271);
|
auto fold = platform::StratifiedKFold(5, y_, 271);
|
||||||
// save input dataset
|
|
||||||
dataset_ = torch::clone(dataset);
|
dataset_ = torch::clone(dataset);
|
||||||
|
// save input dataset
|
||||||
auto [train, test] = fold.getFold(0);
|
auto [train, test] = fold.getFold(0);
|
||||||
auto train_t = torch::tensor(train);
|
auto train_t = torch::tensor(train);
|
||||||
auto test_t = torch::tensor(test);
|
auto test_t = torch::tensor(test);
|
||||||
// Get train and validation sets
|
// Get train and validation sets
|
||||||
X_train = dataset.index({ "...", train_t });
|
X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), train_t });
|
||||||
y_train = dataset.index({ -1, train_t });
|
y_train = dataset.index({ -1, train_t });
|
||||||
X_test = dataset.index({ "...", test_t });
|
X_test = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), test_t });
|
||||||
y_test = dataset.index({ -1, test_t });
|
y_test = dataset.index({ -1, test_t });
|
||||||
// Build dataset with train data
|
|
||||||
dataset = X_train;
|
dataset = X_train;
|
||||||
buildDataset(y_train);
|
|
||||||
m = X_train.size(1);
|
m = X_train.size(1);
|
||||||
auto n_classes = states.at(className).size();
|
auto n_classes = states.at(className).size();
|
||||||
metrics = Metrics(dataset, features, className, n_classes);
|
metrics = Metrics(dataset, features, className, n_classes);
|
||||||
|
// Build dataset with train data
|
||||||
|
buildDataset(y_train);
|
||||||
}
|
}
|
||||||
void BoostAODE::trainModel(const torch::Tensor& weights)
|
void BoostAODE::trainModel(const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
@ -56,9 +60,18 @@ namespace bayesnet {
|
|||||||
Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
|
Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
|
||||||
bool exitCondition = false;
|
bool exitCondition = false;
|
||||||
unordered_set<int> featuresUsed;
|
unordered_set<int> featuresUsed;
|
||||||
|
// Variables to control the accuracy finish condition
|
||||||
|
double priorAccuracy = 0.0;
|
||||||
|
double delta = 1.0;
|
||||||
|
double threshold = 1e-4;
|
||||||
|
int tolerance = convergence ? 5 : INT_MAX; // number of times the accuracy can be lower than the threshold
|
||||||
|
int count = 0; // number of times the accuracy is lower than the threshold
|
||||||
|
fitted = true; // to enable predict
|
||||||
// Step 0: Set the finish condition
|
// Step 0: Set the finish condition
|
||||||
// if not repeatSparent a finish condition is run out of features
|
// if not repeatSparent a finish condition is run out of features
|
||||||
// n_models == maxModels
|
// n_models == maxModels
|
||||||
|
// epsilon sub t > 0.5 => inverse the weights policy
|
||||||
|
// validation error is not decreasing
|
||||||
while (!exitCondition) {
|
while (!exitCondition) {
|
||||||
// Step 1: Build ranking with mutual information
|
// Step 1: Build ranking with mutual information
|
||||||
auto featureSelection = metrics.SelectKBestWeighted(weights_, ascending, n); // Get all the features sorted
|
auto featureSelection = metrics.SelectKBestWeighted(weights_, ascending, n); // Get all the features sorted
|
||||||
@ -81,7 +94,6 @@ namespace bayesnet {
|
|||||||
}
|
}
|
||||||
featuresUsed.insert(feature);
|
featuresUsed.insert(feature);
|
||||||
model = std::make_unique<SPODE>(feature);
|
model = std::make_unique<SPODE>(feature);
|
||||||
n_models++;
|
|
||||||
model->fit(dataset, features, className, states, weights_);
|
model->fit(dataset, features, className, states, weights_);
|
||||||
auto ypred = model->predict(X_train);
|
auto ypred = model->predict(X_train);
|
||||||
// Step 3.1: Compute the classifier amount of say
|
// Step 3.1: Compute the classifier amount of say
|
||||||
@ -102,12 +114,22 @@ namespace bayesnet {
|
|||||||
// Step 3.4: Store classifier and its accuracy to weigh its future vote
|
// Step 3.4: Store classifier and its accuracy to weigh its future vote
|
||||||
models.push_back(std::move(model));
|
models.push_back(std::move(model));
|
||||||
significanceModels.push_back(alpha_t);
|
significanceModels.push_back(alpha_t);
|
||||||
exitCondition = n_models == maxModels && repeatSparent || epsilon_t > 0.5;
|
n_models++;
|
||||||
|
auto y_val_predict = predict(X_test);
|
||||||
|
double accuracy = (y_val_predict == y_test).sum().item<double>() / (double)y_test.size(0);
|
||||||
|
if (priorAccuracy == 0) {
|
||||||
|
priorAccuracy = accuracy;
|
||||||
|
} else {
|
||||||
|
delta = accuracy - priorAccuracy;
|
||||||
|
}
|
||||||
|
if (delta < threshold) {
|
||||||
|
count++;
|
||||||
|
}
|
||||||
|
exitCondition = n_models == maxModels && repeatSparent || epsilon_t > 0.5 || count > tolerance;
|
||||||
}
|
}
|
||||||
if (featuresUsed.size() != features.size()) {
|
if (featuresUsed.size() != features.size()) {
|
||||||
status = WARNING;
|
status = WARNING;
|
||||||
}
|
}
|
||||||
weights.copy_(weights_);
|
|
||||||
}
|
}
|
||||||
vector<string> BoostAODE::graph(const string& title) const
|
vector<string> BoostAODE::graph(const string& title) const
|
||||||
{
|
{
|
||||||
|
@ -19,6 +19,7 @@ namespace bayesnet {
|
|||||||
bool repeatSparent = false;
|
bool repeatSparent = false;
|
||||||
int maxModels = 0;
|
int maxModels = 0;
|
||||||
bool ascending = false; //Process KBest features ascending or descending order
|
bool ascending = false; //Process KBest features ascending or descending order
|
||||||
|
bool convergence = false; //if true, stop when the model does not improve
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
@ -24,7 +24,7 @@ namespace bayesnet {
|
|||||||
// i.e. votes[0] contains how much value has the value 0 of class. That value is generated by the models predictions
|
// i.e. votes[0] contains how much value has the value 0 of class. That value is generated by the models predictions
|
||||||
vector<double> votes(numClasses, 0.0);
|
vector<double> votes(numClasses, 0.0);
|
||||||
for (int j = 0; j < n_models; ++j) {
|
for (int j = 0; j < n_models; ++j) {
|
||||||
votes[y_pred_[i][j]] += significanceModels[j];
|
votes[y_pred_[i][j]] += significanceModels.at(j);
|
||||||
}
|
}
|
||||||
// argsort in descending order
|
// argsort in descending order
|
||||||
auto indices = argsort(votes);
|
auto indices = argsort(votes);
|
||||||
|
Loading…
Reference in New Issue
Block a user