Add traintest split in gridsearch
This commit is contained in:
@@ -115,23 +115,31 @@ namespace platform {
|
||||
}
|
||||
void Experiment::cross_validation(const std::string& fileName, bool quiet, bool no_train_score, bool generate_fold_files)
|
||||
{
|
||||
//
|
||||
// Load dataset and prepare data
|
||||
//
|
||||
auto datasets = Datasets(false, Paths::datasets()); // Never discretize here
|
||||
// Get dataset
|
||||
// -------------- auto [X, y] = datasets.getTensors(fileName);
|
||||
// -------------- auto states = datasets.getStates(fileName);
|
||||
auto features = datasets.getFeatures(fileName);
|
||||
auto samples = datasets.getNSamples(fileName);
|
||||
auto className = datasets.getClassName(fileName);
|
||||
auto labels = datasets.getLabels(fileName);
|
||||
int num_classes = labels.size();
|
||||
auto& dataset = datasets.getDataset(fileName);
|
||||
dataset.load();
|
||||
auto [X, y] = dataset.getTensors(); // Only need y for folding
|
||||
auto features = dataset.getFeatures();
|
||||
auto n_features = dataset.getNFeatures();
|
||||
auto n_samples = dataset.getNSamples();
|
||||
auto className = dataset.getClassName();
|
||||
auto labels = dataset.getLabels();
|
||||
int num_classes = dataset.getNClasses();
|
||||
if (!quiet) {
|
||||
std::cout << " " << setw(5) << samples << " " << setw(5) << features.size() << flush;
|
||||
std::cout << " " << setw(5) << n_samples << " " << setw(5) << n_features << flush;
|
||||
}
|
||||
//
|
||||
// Prepare Result
|
||||
//
|
||||
auto partial_result = PartialResult();
|
||||
partial_result.setSamples(samples).setFeatures(features.size()).setClasses(num_classes);
|
||||
partial_result.setSamples(n_samples).setFeatures(n_features).setClasses(num_classes);
|
||||
partial_result.setHyperparameters(hyperparameters.get(fileName));
|
||||
//
|
||||
// Initialize results std::vectors
|
||||
//
|
||||
int nResults = nfolds * static_cast<int>(randomSeeds.size());
|
||||
auto accuracy_test = torch::zeros({ nResults }, torch::kFloat64);
|
||||
auto accuracy_train = torch::zeros({ nResults }, torch::kFloat64);
|
||||
@@ -146,6 +154,9 @@ namespace platform {
|
||||
Timer train_timer, test_timer;
|
||||
int item = 0;
|
||||
bool first_seed = true;
|
||||
//
|
||||
// Loop over random seeds
|
||||
//
|
||||
for (auto seed : randomSeeds) {
|
||||
if (!quiet) {
|
||||
string prefix = " ";
|
||||
@@ -159,25 +170,30 @@ namespace platform {
|
||||
if (stratified)
|
||||
fold = new folding::StratifiedKFold(nfolds, y, seed);
|
||||
else
|
||||
fold = new folding::KFold(nfolds, y.size(0), seed);
|
||||
fold = new folding::KFold(nfolds, n_samples, seed);
|
||||
//
|
||||
// Loop over folds
|
||||
//
|
||||
for (int nfold = 0; nfold < nfolds; nfold++) {
|
||||
auto clf = Models::instance()->create(result.getModel());
|
||||
setModelVersion(clf->getVersion());
|
||||
auto valid = clf->getValidHyperparameters();
|
||||
hyperparameters.check(valid, fileName);
|
||||
clf->setHyperparameters(hyperparameters.get(fileName));
|
||||
//
|
||||
// Split train - test dataset
|
||||
//
|
||||
train_timer.start();
|
||||
auto [train, test] = fold->getFold(nfold);
|
||||
auto [X_train, X_test, y_train, y_test] = datasets.getTrainTestTensors(fileName, train, test);
|
||||
// Posibilidad de quitar todos los métodos de datasets y dejar un sólo de getDataset que devuelva
|
||||
// una referencia al objeto dataset y trabajar directamente con él.
|
||||
auto states = datasets.getStates(fileName);
|
||||
auto [X_train, X_test, y_train, y_test] = dataset.getTrainTestTensors(train, test);
|
||||
auto states = dataset.getStates();
|
||||
if (generate_fold_files)
|
||||
generate_files(fileName, discretized, stratified, seed, nfold, X_train, y_train, X_test, y_test, train, test);
|
||||
if (!quiet)
|
||||
showProgress(nfold + 1, getColor(clf->getStatus()), "a");
|
||||
//
|
||||
// Train model
|
||||
//
|
||||
clf->fit(X_train, y_train, features, className, states);
|
||||
if (!quiet)
|
||||
showProgress(nfold + 1, getColor(clf->getStatus()), "b");
|
||||
@@ -189,14 +205,18 @@ namespace platform {
|
||||
num_states[item] = clf->getNumberOfStates();
|
||||
train_time[item] = train_timer.getDuration();
|
||||
double accuracy_train_value = 0.0;
|
||||
//
|
||||
// Score train
|
||||
//
|
||||
if (!no_train_score) {
|
||||
auto y_predict = clf->predict(X_train);
|
||||
Scores scores(y_train, y_predict, num_classes, labels);
|
||||
accuracy_train_value = scores.accuracy();
|
||||
confusion_matrices_train.push_back(scores.get_confusion_matrix_json(true));
|
||||
}
|
||||
//
|
||||
// Test model
|
||||
//
|
||||
if (!quiet)
|
||||
showProgress(nfold + 1, getColor(clf->getStatus()), "c");
|
||||
test_timer.start();
|
||||
@@ -209,7 +229,9 @@ namespace platform {
|
||||
confusion_matrices.push_back(scores.get_confusion_matrix_json(true));
|
||||
if (!quiet)
|
||||
std::cout << "\b\b\b, " << flush;
|
||||
//
|
||||
// Store results and times in std::vector
|
||||
//
|
||||
partial_result.addScoreTrain(accuracy_train_value);
|
||||
partial_result.addScoreTest(accuracy_test_value);
|
||||
partial_result.addTimeTrain(train_time[item].item<double>());
|
||||
@@ -220,6 +242,9 @@ namespace platform {
|
||||
std::cout << "end. " << flush;
|
||||
delete fold;
|
||||
}
|
||||
//
|
||||
// Store result totals in Result
|
||||
//
|
||||
partial_result.setScoreTest(torch::mean(accuracy_test).item<double>()).setScoreTrain(torch::mean(accuracy_train).item<double>());
|
||||
partial_result.setScoreTestStd(torch::std(accuracy_test).item<double>()).setScoreTrainStd(torch::std(accuracy_train).item<double>());
|
||||
partial_result.setTrainTime(torch::mean(train_time).item<double>()).setTestTime(torch::mean(test_time).item<double>());
|
||||
|
Reference in New Issue
Block a user