If convergence is disabled, train on all the data and don't predict on the test (validation) split

This commit is contained in:
Ricardo Montañana Gómez 2023-09-10 19:50:36 +02:00
parent 506369e46b
commit 41257ed566
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
2 changed files with 40 additions and 30 deletions

View File

@ -32,6 +32,8 @@ namespace bayesnet {
void BoostAODE::validationInit() void BoostAODE::validationInit()
{ {
auto y_ = dataset.index({ -1, "..." }); auto y_ = dataset.index({ -1, "..." });
if (convergence) {
// Prepare train & validation sets from train data
auto fold = platform::StratifiedKFold(5, y_, 271); auto fold = platform::StratifiedKFold(5, y_, 271);
dataset_ = torch::clone(dataset); dataset_ = torch::clone(dataset);
// save input dataset // save input dataset
@ -49,6 +51,12 @@ namespace bayesnet {
metrics = Metrics(dataset, features, className, n_classes); metrics = Metrics(dataset, features, className, n_classes);
// Build dataset with train data // Build dataset with train data
buildDataset(y_train); buildDataset(y_train);
} else {
// Use all data to train
X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." });
y_train = y_;
}
} }
void BoostAODE::trainModel(const torch::Tensor& weights) void BoostAODE::trainModel(const torch::Tensor& weights)
{ {
@ -64,7 +72,7 @@ namespace bayesnet {
double priorAccuracy = 0.0; double priorAccuracy = 0.0;
double delta = 1.0; double delta = 1.0;
double threshold = 1e-4; double threshold = 1e-4;
int tolerance = convergence ? 5 : INT_MAX; // number of times the accuracy can be lower than the threshold int tolerance = 5; // number of times the accuracy can be lower than the threshold
int count = 0; // number of times the accuracy is lower than the threshold int count = 0; // number of times the accuracy is lower than the threshold
fitted = true; // to enable predict fitted = true; // to enable predict
// Step 0: Set the finish condition // Step 0: Set the finish condition
@ -115,6 +123,7 @@ namespace bayesnet {
models.push_back(std::move(model)); models.push_back(std::move(model));
significanceModels.push_back(alpha_t); significanceModels.push_back(alpha_t);
n_models++; n_models++;
if (convergence) {
auto y_val_predict = predict(X_test); auto y_val_predict = predict(X_test);
double accuracy = (y_val_predict == y_test).sum().item<double>() / (double)y_test.size(0); double accuracy = (y_val_predict == y_test).sum().item<double>() / (double)y_test.size(0);
if (priorAccuracy == 0) { if (priorAccuracy == 0) {
@ -125,6 +134,7 @@ namespace bayesnet {
if (delta < threshold) { if (delta < threshold) {
count++; count++;
} }
}
exitCondition = n_models == maxModels && repeatSparent || epsilon_t > 0.5 || count > tolerance; exitCondition = n_models == maxModels && repeatSparent || epsilon_t > 0.5 || count > tolerance;
} }
if (featuresUsed.size() != features.size()) { if (featuresUsed.size() != features.size()) {

View File

@ -132,9 +132,9 @@ namespace bayesnet {
void Network::setStates(const map<string, vector<int>>& states) void Network::setStates(const map<string, vector<int>>& states)
{ {
// Set states to every Node in the network // Set states to every Node in the network
for (int i = 0; i < features.size(); ++i) { for_each(features.begin(), features.end(), [this, &states](const string& feature) {
nodes.at(features.at(i))->setNumStates(states.at(features[i]).size()); nodes.at(feature)->setNumStates(states.at(feature).size());
} });
classNumStates = nodes.at(className)->getNumStates(); classNumStates = nodes.at(className)->getNumStates();
} }
// X comes in nxm, where n is the number of features and m the number of samples // X comes in nxm, where n is the number of features and m the number of samples