Refactor cross_validation
This commit is contained in:
@@ -126,19 +126,9 @@ int main(int argc, char** argv)
|
||||
auto samples = datasets.getNSamples(fileName);
|
||||
auto className = datasets.getClassName(fileName);
|
||||
cout << " (" << setw(5) << samples << "," << setw(3) << features.size() << ") " << flush;
|
||||
for (auto seed : seeds) {
|
||||
cout << "(" << seed << ") " << flush;
|
||||
Fold* fold;
|
||||
if (stratified)
|
||||
fold = new StratifiedKFold(n_folds, y, seed);
|
||||
else
|
||||
fold = new KFold(n_folds, samples, seed);
|
||||
auto result = platform::cross_validation(fold, model_name, X, y, features, className, states);
|
||||
result.setDataset(fileName);
|
||||
experiment.setModelVersion("-FIXME-");
|
||||
experiment.addResult(result);
|
||||
delete fold;
|
||||
}
|
||||
auto result = experiment.cross_validation(model_name, X, y, features, className, states);
|
||||
result.setDataset(fileName);
|
||||
experiment.addResult(result);
|
||||
cout << endl;
|
||||
}
|
||||
experiment.setDuration(timer.getDuration());
|
||||
|
Reference in New Issue
Block a user