Add traintest split in gridsearch

This commit is contained in:
2024-06-07 11:05:59 +02:00
parent 5dd3deca1a
commit 361c51d864
8 changed files with 213 additions and 247 deletions

View File

@@ -115,23 +115,31 @@ namespace platform {
}
void Experiment::cross_validation(const std::string& fileName, bool quiet, bool no_train_score, bool generate_fold_files)
{
//
// Load dataset and prepare data
//
auto datasets = Datasets(false, Paths::datasets()); // Never discretize here
// Get dataset
// -------------- auto [X, y] = datasets.getTensors(fileName);
// -------------- auto states = datasets.getStates(fileName);
auto features = datasets.getFeatures(fileName);
auto samples = datasets.getNSamples(fileName);
auto className = datasets.getClassName(fileName);
auto labels = datasets.getLabels(fileName);
int num_classes = labels.size();
auto& dataset = datasets.getDataset(fileName);
dataset.load();
auto [X, y] = dataset.getTensors(); // Only need y for folding
auto features = dataset.getFeatures();
auto n_features = dataset.getNFeatures();
auto n_samples = dataset.getNSamples();
auto className = dataset.getClassName();
auto labels = dataset.getLabels();
int num_classes = dataset.getNClasses();
if (!quiet) {
std::cout << " " << setw(5) << samples << " " << setw(5) << features.size() << flush;
std::cout << " " << setw(5) << n_samples << " " << setw(5) << n_features << flush;
}
//
// Prepare Result
//
auto partial_result = PartialResult();
partial_result.setSamples(samples).setFeatures(features.size()).setClasses(num_classes);
partial_result.setSamples(n_samples).setFeatures(n_features).setClasses(num_classes);
partial_result.setHyperparameters(hyperparameters.get(fileName));
//
// Initialize result containers (one entry per fold and random seed)
//
int nResults = nfolds * static_cast<int>(randomSeeds.size());
auto accuracy_test = torch::zeros({ nResults }, torch::kFloat64);
auto accuracy_train = torch::zeros({ nResults }, torch::kFloat64);
@@ -146,6 +154,9 @@ namespace platform {
Timer train_timer, test_timer;
int item = 0;
bool first_seed = true;
//
// Loop over random seeds
//
for (auto seed : randomSeeds) {
if (!quiet) {
string prefix = " ";
@@ -159,25 +170,30 @@ namespace platform {
if (stratified)
fold = new folding::StratifiedKFold(nfolds, y, seed);
else
fold = new folding::KFold(nfolds, y.size(0), seed);
fold = new folding::KFold(nfolds, n_samples, seed);
//
// Loop over folds
//
for (int nfold = 0; nfold < nfolds; nfold++) {
auto clf = Models::instance()->create(result.getModel());
setModelVersion(clf->getVersion());
auto valid = clf->getValidHyperparameters();
hyperparameters.check(valid, fileName);
clf->setHyperparameters(hyperparameters.get(fileName));
//
// Split train - test dataset
//
train_timer.start();
auto [train, test] = fold->getFold(nfold);
auto [X_train, X_test, y_train, y_test] = datasets.getTrainTestTensors(fileName, train, test);
// Possible refactor: remove all the per-file accessor methods from datasets and keep a single getDataset
// that returns a reference to the dataset object, then work directly with that object.
auto states = datasets.getStates(fileName);
auto [X_train, X_test, y_train, y_test] = dataset.getTrainTestTensors(train, test);
auto states = dataset.getStates();
if (generate_fold_files)
generate_files(fileName, discretized, stratified, seed, nfold, X_train, y_train, X_test, y_test, train, test);
if (!quiet)
showProgress(nfold + 1, getColor(clf->getStatus()), "a");
//
// Train model
//
clf->fit(X_train, y_train, features, className, states);
if (!quiet)
showProgress(nfold + 1, getColor(clf->getStatus()), "b");
@@ -189,14 +205,18 @@ namespace platform {
num_states[item] = clf->getNumberOfStates();
train_time[item] = train_timer.getDuration();
double accuracy_train_value = 0.0;
//
// Score train
//
if (!no_train_score) {
auto y_predict = clf->predict(X_train);
Scores scores(y_train, y_predict, num_classes, labels);
accuracy_train_value = scores.accuracy();
confusion_matrices_train.push_back(scores.get_confusion_matrix_json(true));
}
//
// Test model
//
if (!quiet)
showProgress(nfold + 1, getColor(clf->getStatus()), "c");
test_timer.start();
@@ -209,7 +229,9 @@ namespace platform {
confusion_matrices.push_back(scores.get_confusion_matrix_json(true));
if (!quiet)
std::cout << "\b\b\b, " << flush;
//
// Store this fold's scores and times in the partial result
//
partial_result.addScoreTrain(accuracy_train_value);
partial_result.addScoreTest(accuracy_test_value);
partial_result.addTimeTrain(train_time[item].item<double>());
@@ -220,6 +242,9 @@ namespace platform {
std::cout << "end. " << flush;
delete fold;
}
//
// Store result totals in Result
//
partial_result.setScoreTest(torch::mean(accuracy_test).item<double>()).setScoreTrain(torch::mean(accuracy_train).item<double>());
partial_result.setScoreTestStd(torch::std(accuracy_test).item<double>()).setScoreTrainStd(torch::std(accuracy_train).item<double>());
partial_result.setTrainTime(torch::mean(train_time).item<double>()).setTestTime(torch::mean(test_time).item<double>());