If ! convergence don't predict test
This commit is contained in:
parent
506369e46b
commit
41257ed566
@ -32,23 +32,31 @@ namespace bayesnet {
|
||||
void BoostAODE::validationInit()
|
||||
{
|
||||
auto y_ = dataset.index({ -1, "..." });
|
||||
auto fold = platform::StratifiedKFold(5, y_, 271);
|
||||
dataset_ = torch::clone(dataset);
|
||||
// save input dataset
|
||||
auto [train, test] = fold.getFold(0);
|
||||
auto train_t = torch::tensor(train);
|
||||
auto test_t = torch::tensor(test);
|
||||
// Get train and validation sets
|
||||
X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), train_t });
|
||||
y_train = dataset.index({ -1, train_t });
|
||||
X_test = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), test_t });
|
||||
y_test = dataset.index({ -1, test_t });
|
||||
dataset = X_train;
|
||||
m = X_train.size(1);
|
||||
auto n_classes = states.at(className).size();
|
||||
metrics = Metrics(dataset, features, className, n_classes);
|
||||
// Build dataset with train data
|
||||
buildDataset(y_train);
|
||||
if (convergence) {
|
||||
// Prepare train & validation sets from train data
|
||||
auto fold = platform::StratifiedKFold(5, y_, 271);
|
||||
dataset_ = torch::clone(dataset);
|
||||
// save input dataset
|
||||
auto [train, test] = fold.getFold(0);
|
||||
auto train_t = torch::tensor(train);
|
||||
auto test_t = torch::tensor(test);
|
||||
// Get train and validation sets
|
||||
X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), train_t });
|
||||
y_train = dataset.index({ -1, train_t });
|
||||
X_test = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), test_t });
|
||||
y_test = dataset.index({ -1, test_t });
|
||||
dataset = X_train;
|
||||
m = X_train.size(1);
|
||||
auto n_classes = states.at(className).size();
|
||||
metrics = Metrics(dataset, features, className, n_classes);
|
||||
// Build dataset with train data
|
||||
buildDataset(y_train);
|
||||
} else {
|
||||
// Use all data to train
|
||||
X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), "..." });
|
||||
y_train = y_;
|
||||
}
|
||||
|
||||
}
|
||||
void BoostAODE::trainModel(const torch::Tensor& weights)
|
||||
{
|
||||
@ -64,7 +72,7 @@ namespace bayesnet {
|
||||
double priorAccuracy = 0.0;
|
||||
double delta = 1.0;
|
||||
double threshold = 1e-4;
|
||||
int tolerance = convergence ? 5 : INT_MAX; // number of times the accuracy can be lower than the threshold
|
||||
int tolerance = 5; // number of times the accuracy can be lower than the threshold
|
||||
int count = 0; // number of times the accuracy is lower than the threshold
|
||||
fitted = true; // to enable predict
|
||||
// Step 0: Set the finish condition
|
||||
@ -115,15 +123,17 @@ namespace bayesnet {
|
||||
models.push_back(std::move(model));
|
||||
significanceModels.push_back(alpha_t);
|
||||
n_models++;
|
||||
auto y_val_predict = predict(X_test);
|
||||
double accuracy = (y_val_predict == y_test).sum().item<double>() / (double)y_test.size(0);
|
||||
if (priorAccuracy == 0) {
|
||||
priorAccuracy = accuracy;
|
||||
} else {
|
||||
delta = accuracy - priorAccuracy;
|
||||
}
|
||||
if (delta < threshold) {
|
||||
count++;
|
||||
if (convergence) {
|
||||
auto y_val_predict = predict(X_test);
|
||||
double accuracy = (y_val_predict == y_test).sum().item<double>() / (double)y_test.size(0);
|
||||
if (priorAccuracy == 0) {
|
||||
priorAccuracy = accuracy;
|
||||
} else {
|
||||
delta = accuracy - priorAccuracy;
|
||||
}
|
||||
if (delta < threshold) {
|
||||
count++;
|
||||
}
|
||||
}
|
||||
exitCondition = n_models == maxModels && repeatSparent || epsilon_t > 0.5 || count > tolerance;
|
||||
}
|
||||
|
@ -132,9 +132,9 @@ namespace bayesnet {
|
||||
void Network::setStates(const map<string, vector<int>>& states)
|
||||
{
|
||||
// Set states to every Node in the network
|
||||
for (int i = 0; i < features.size(); ++i) {
|
||||
nodes.at(features.at(i))->setNumStates(states.at(features[i]).size());
|
||||
}
|
||||
for_each(features.begin(), features.end(), [this, &states](const string& feature) {
|
||||
nodes.at(feature)->setNumStates(states.at(feature).size());
|
||||
});
|
||||
classNumStates = nodes.at(className)->getNumStates();
|
||||
}
|
||||
// X comes in nxm, where n is the number of features and m the number of samples
|
||||
|
Loading…
Reference in New Issue
Block a user