Refactor cross_validation

This commit is contained in:
2023-07-29 16:44:07 +02:00
parent b9e76becce
commit adc0ca238f
3 changed files with 55 additions and 49 deletions

View File

@@ -126,19 +126,9 @@ int main(int argc, char** argv)
auto samples = datasets.getNSamples(fileName);
auto className = datasets.getClassName(fileName);
cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush;
for (auto seed : seeds) {
cout << "(" << seed << ") " << flush;
Fold* fold;
if (stratified)
fold = new StratifiedKFold(n_folds, y, seed);
else
fold = new KFold(n_folds, samples, seed);
auto result = platform::cross_validation(fold, model_name, X, y, features, className, states);
result.setDataset(fileName);
experiment.setModelVersion("-FIXME-");
experiment.addResult(result);
delete fold;
}
auto result = experiment.cross_validation(model_name, X, y, features, className, states);
result.setDataset(fileName);
experiment.addResult(result);
cout << endl;
}
experiment.setDuration(timer.getDuration());