If ! convergence don't predict test

This commit is contained in:
Ricardo Montañana Gómez 2023-09-10 19:50:36 +02:00
parent 506369e46b
commit 41257ed566
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
2 changed files with 40 additions and 30 deletions

View File

@ -32,23 +32,31 @@ namespace bayesnet {
void BoostAODE::validationInit() void BoostAODE::validationInit()
{ {
auto y_ = dataset.index({ -1, "..." }); auto y_ = dataset.index({ -1, "..." });
auto fold = platform::StratifiedKFold(5, y_, 271); if (convergence) {
dataset_ = torch::clone(dataset); // Prepare train & validation sets from train data
// save input dataset auto fold = platform::StratifiedKFold(5, y_, 271);
auto [train, test] = fold.getFold(0); dataset_ = torch::clone(dataset);
auto train_t = torch::tensor(train); // save input dataset
auto test_t = torch::tensor(test); auto [train, test] = fold.getFold(0);
// Get train and validation sets auto train_t = torch::tensor(train);
X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), train_t }); auto test_t = torch::tensor(test);
y_train = dataset.index({ -1, train_t }); // Get train and validation sets
X_test = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), test_t }); X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), train_t });
y_test = dataset.index({ -1, test_t }); y_train = dataset.index({ -1, train_t });
dataset = X_train; X_test = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), test_t });
m = X_train.size(1); y_test = dataset.index({ -1, test_t });
auto n_classes = states.at(className).size(); dataset = X_train;
metrics = Metrics(dataset, features, className, n_classes); m = X_train.size(1);
// Build dataset with train data auto n_classes = states.at(className).size();
buildDataset(y_train); metrics = Metrics(dataset, features, className, n_classes);
// Build dataset with train data
buildDataset(y_train);
} else {
// Use all data to train
X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." });
y_train = y_;
}
} }
void BoostAODE::trainModel(const torch::Tensor& weights) void BoostAODE::trainModel(const torch::Tensor& weights)
{ {
@ -64,7 +72,7 @@ namespace bayesnet {
double priorAccuracy = 0.0; double priorAccuracy = 0.0;
double delta = 1.0; double delta = 1.0;
double threshold = 1e-4; double threshold = 1e-4;
int tolerance = convergence ? 5 : INT_MAX; // number of times the accuracy can be lower than the threshold int tolerance = 5; // number of times the accuracy can be lower than the threshold
int count = 0; // number of times the accuracy is lower than the threshold int count = 0; // number of times the accuracy is lower than the threshold
fitted = true; // to enable predict fitted = true; // to enable predict
// Step 0: Set the finish condition // Step 0: Set the finish condition
@ -115,15 +123,17 @@ namespace bayesnet {
models.push_back(std::move(model)); models.push_back(std::move(model));
significanceModels.push_back(alpha_t); significanceModels.push_back(alpha_t);
n_models++; n_models++;
auto y_val_predict = predict(X_test); if (convergence) {
double accuracy = (y_val_predict == y_test).sum().item<double>() / (double)y_test.size(0); auto y_val_predict = predict(X_test);
if (priorAccuracy == 0) { double accuracy = (y_val_predict == y_test).sum().item<double>() / (double)y_test.size(0);
priorAccuracy = accuracy; if (priorAccuracy == 0) {
} else { priorAccuracy = accuracy;
delta = accuracy - priorAccuracy; } else {
} delta = accuracy - priorAccuracy;
if (delta < threshold) { }
count++; if (delta < threshold) {
count++;
}
} }
exitCondition = n_models == maxModels && repeatSparent || epsilon_t > 0.5 || count > tolerance; exitCondition = n_models == maxModels && repeatSparent || epsilon_t > 0.5 || count > tolerance;
} }

View File

@ -132,9 +132,9 @@ namespace bayesnet {
void Network::setStates(const map<string, vector<int>>& states) void Network::setStates(const map<string, vector<int>>& states)
{ {
// Set states to every Node in the network // Set states to every Node in the network
for (int i = 0; i < features.size(); ++i) { for_each(features.begin(), features.end(), [this, &states](const string& feature) {
nodes.at(features.at(i))->setNumStates(states.at(features[i]).size()); nodes.at(feature)->setNumStates(states.at(feature).size());
} });
classNumStates = nodes.at(className)->getNumStates(); classNumStates = nodes.at(className)->getNumStates();
} }
// X comes in nxm, where n is the number of features and m the number of samples // X comes in nxm, where n is the number of features and m the number of samples