Add discretization algorithm management to b_main & Dataset

This commit is contained in:
2024-06-07 09:00:51 +02:00
parent 2202a81782
commit 5dd3deca1a
5 changed files with 28 additions and 27 deletions

View File

@@ -117,20 +117,19 @@ namespace platform {
{
auto datasets = Datasets(false, Paths::datasets()); // Never discretize here
// Get dataset
auto [X, y] = datasets.getTensors(fileName);
auto states = datasets.getStates(fileName);
// -------------- auto [X, y] = datasets.getTensors(fileName);
// -------------- auto states = datasets.getStates(fileName);
auto features = datasets.getFeatures(fileName);
auto samples = datasets.getNSamples(fileName);
auto className = datasets.getClassName(fileName);
auto labels = datasets.getLabels(fileName);
int num_classes = states[className].size() == 0 ? labels.size() : states[className].size();
int num_classes = labels.size();
if (!quiet) {
std::cout << " " << setw(5) << samples << " " << setw(5) << features.size() << flush;
}
// Prepare Result
auto partial_result = PartialResult();
auto [values, counts] = at::_unique(y);
partial_result.setSamples(X.size(1)).setFeatures(X.size(0)).setClasses(values.size(0));
partial_result.setSamples(samples).setFeatures(features.size()).setClasses(num_classes);
partial_result.setHyperparameters(hyperparameters.get(fileName));
// Initialize results std::vectors
int nResults = nfolds * static_cast<int>(randomSeeds.size());
@@ -170,18 +169,10 @@ namespace platform {
// Split train - test dataset
train_timer.start();
auto [train, test] = fold->getFold(nfold);
auto train_t = torch::tensor(train);
auto test_t = torch::tensor(test);
auto X_train = X.index({ "...", train_t });
auto y_train = y.index({ train_t });
auto X_test = X.index({ "...", test_t });
auto y_test = y.index({ test_t });
if (discretized) {
// compute states too
// discretizer->fit(X_train, y_train);
// X_train = discretizer->transform(X_train);
// X_test = discretizer->transform(X_test);
}
auto [X_train, X_test, y_train, y_test] = datasets.getTrainTestTensors(fileName, train, test);
// Posibilidad de quitar todos los métodos de datasets y dejar un sólo de getDataset que devuelva
// una referencia al objeto dataset y trabajar directamente con él.
auto states = datasets.getStates(fileName);
if (generate_fold_files)
generate_files(fileName, discretized, stratified, seed, nfold, X_train, y_train, X_test, y_test, train, test);
if (!quiet)