From 361c51d8647c805542ce8c402f9c83858c7a6c99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Fri, 7 Jun 2024 11:05:59 +0200 Subject: [PATCH] Add traintest split in gridsearch --- src/common/Dataset.cpp | 196 ++++++++++++++++++++------------ src/common/Dataset.h | 15 ++- src/common/Datasets.cpp | 106 +---------------- src/common/Datasets.h | 21 +--- src/grid/GridSearch.cpp | 26 ++--- src/main/Experiment.cpp | 55 ++++++--- src/reports/DatasetsConsole.cpp | 30 ++--- src/reports/ReportBase.cpp | 11 +- 8 files changed, 213 insertions(+), 247 deletions(-) diff --git a/src/common/Dataset.cpp b/src/common/Dataset.cpp index 194e23d..7265787 100644 --- a/src/common/Dataset.cpp +++ b/src/common/Dataset.cpp @@ -15,10 +15,6 @@ namespace platform { { return name; } - std::string Dataset::getClassName() const - { - return className; - } std::vector Dataset::getFeatures() const { if (loaded) { @@ -43,6 +39,42 @@ namespace platform { throw std::invalid_argument(message_dataset_not_loaded); } } + std::string Dataset::getClassName() const + { + return className; + } + int Dataset::getNClasses() const + { + if (loaded) { + if (discretize) { + return states.at(className).size(); + } + return *std::max_element(yv.begin(), yv.end()) + 1; + } else { + throw std::invalid_argument(message_dataset_not_loaded); + } + } + std::vector Dataset::getLabels() const + { + // Return the labels factorization result + if (loaded) { + return labels; + } else { + throw std::invalid_argument(message_dataset_not_loaded); + } + } + std::vector Dataset::getClassesCounts() const + { + if (loaded) { + std::vector counts(*std::max_element(yv.begin(), yv.end()) + 1); + for (auto y : yv) { + counts[y]++; + } + return counts; + } else { + throw std::invalid_argument(message_dataset_not_loaded); + } + } std::map> Dataset::getStates() const { if (loaded) { @@ -70,7 +102,6 @@ namespace platform { pair Dataset::getTensors() { if (loaded) { - buildTensors(); return { X, y }; } else { throw std::invalid_argument(message_dataset_not_loaded); @@ -79,29 +110,32 @@ namespace platform { void Dataset::load_csv() { ifstream file(path + "/" + name + ".csv"); - if (file.is_open()) { - std::string line; - getline(file, line); - std::vector tokens = split(line, ','); - features = std::vector(tokens.begin(), tokens.end() - 1); - if (className == "-1") { - className = tokens.back(); - } - for (auto i = 0; i < features.size(); ++i) { - Xv.push_back(std::vector()); - } - while (getline(file, line)) { - tokens = split(line, ','); - for (auto i = 0; i < features.size(); ++i) { - Xv[i].push_back(stof(tokens[i])); - } - yv.push_back(stoi(tokens.back())); - } - labels.clear(); - file.close(); - } else { + if (!file.is_open()) { throw std::invalid_argument("Unable to open dataset file."); } + labels.clear(); + std::string line; + getline(file, line); + std::vector tokens = split(line, ','); + features = std::vector(tokens.begin(), tokens.end() - 1); + if (className == "-1") { + className = tokens.back(); + } + for (auto i = 0; i < features.size(); ++i) { + Xv.push_back(std::vector()); + } + while (getline(file, line)) { + tokens = split(line, ','); + for (auto i = 0; i < features.size(); ++i) { + Xv[i].push_back(stof(tokens[i])); + } + auto label = trim(tokens.back()); + if (find(labels.begin(), labels.end(), label) == labels.end()) { + labels.push_back(label); + } + yv.push_back(stoi(label)); + } + file.close(); } void Dataset::computeStates() { @@ -147,32 +181,35 @@ namespace platform { void Dataset::load_rdata() { ifstream file(path + "/" + name + "_R.dat"); - if (file.is_open()) { - std::string line; - getline(file, line); - line = ArffFiles::trim(line); - std::vector tokens = tokenize(line); - transform(tokens.begin(), tokens.end() - 1, back_inserter(features), [](const auto& attribute) { return ArffFiles::trim(attribute); }); - if (className == "-1") { - className = ArffFiles::trim(tokens.back()); - } - for (auto i = 0; i < features.size(); ++i) { - Xv.push_back(std::vector()); - } - while (getline(file, line)) { - tokens = tokenize(line); - // We have to skip the first token, which is the instance number. - for (auto i = 1; i < features.size() + 1; ++i) { - const float value = stof(tokens[i]); - Xv[i - 1].push_back(value); - } - yv.push_back(stoi(tokens.back())); - } - labels.clear(); - file.close(); - } else { + if (!file.is_open()) { throw std::invalid_argument("Unable to open dataset file."); } + std::string line; + labels.clear(); + getline(file, line); + line = ArffFiles::trim(line); + std::vector tokens = tokenize(line); + transform(tokens.begin(), tokens.end() - 1, back_inserter(features), [](const auto& attribute) { return ArffFiles::trim(attribute); }); + if (className == "-1") { + className = ArffFiles::trim(tokens.back()); + } + for (auto i = 0; i < features.size(); ++i) { + Xv.push_back(std::vector()); + } + while (getline(file, line)) { + tokens = tokenize(line); + // We have to skip the first token, which is the instance number. + for (auto i = 1; i < features.size() + 1; ++i) { + const float value = stof(tokens[i]); + Xv[i - 1].push_back(value); + } + auto label = trim(tokens.back()); + if (find(labels.begin(), labels.end(), label) == labels.end()) { + labels.push_back(label); + } + yv.push_back(stoi(label)); + } + file.close(); } void Dataset::load() { @@ -200,27 +237,13 @@ namespace platform { } } } - if (discretize) { - Xd = discretizeDataset(Xv, yv); - computeStates(); - } - loaded = true; - } - void Dataset::buildTensors() - { - if (discretize) { - X = torch::zeros({ static_cast(n_features), static_cast(n_samples) }, torch::kInt32); - } else { - X = torch::zeros({ static_cast(n_features), static_cast(n_samples) }, torch::kFloat32); - } + // Build Tensors + X = torch::zeros({ n_features, n_samples }, torch::kFloat32); for (int i = 0; i < features.size(); ++i) { - if (discretize) { - X.index_put_({ i, "..." }, torch::tensor(Xd[i], torch::kInt32)); - } else { - X.index_put_({ i, "..." }, torch::tensor(Xv[i], torch::kFloat32)); - } + X.index_put_({ i, "..." }, torch::tensor(Xv[i], torch::kFloat32)); } y = torch::tensor(yv, torch::kInt32); + loaded = true; } std::vector Dataset::discretizeDataset(std::vector& X, mdlp::labels_t& y) { @@ -233,9 +256,40 @@ namespace platform { } return Xd; } - std::pair Dataset::getDiscretizedTrainTestTensors() + std::tuple Dataset::getTrainTestTensors(std::vector& train, std::vector& test) { - auto discretizer = Discretization::instance()->create("mdlp"); - return { X_train, X_test }; + if (!loaded) { + throw std::invalid_argument(message_dataset_not_loaded); + } + auto train_t = torch::tensor(train); + int samples_train = train.size(); + int samples_test = test.size(); + auto test_t = torch::tensor(test); + X_train = X.index({ "...", train_t }); + y_train = y.index({ train_t }); + X_test = X.index({ "...", test_t }); + y_test = y.index({ test_t }); + if (discretize) { + auto discretizer = Discretization::instance()->create(discretizer_algorithm); + auto X_train_d = torch::zeros({ n_features, samples_train }, torch::kInt32); + auto X_test_d = torch::zeros({ n_features, samples_test }, torch::kInt32); + for (int feature = 0; feature < n_features; ++feature) { + if (numericFeatures[feature]) { + auto X_train_feature = X_train.index({ feature, "..." }).to(torch::kFloat32); + auto X_test_feature = X_test.index({ feature, "..." }).to(torch::kFloat32); + discretizer->fit(X_train_feature, y_train); + auto X_train_feature_d = discretizer->transform(X_train_feature); + auto X_test_feature_d = discretizer->transform(X_test_feature); + X_train_d.index_put_({ feature, "..." }, X_train_feature_d.to(torch::kInt32)); + X_test_d.index_put_({ feature, "..." }, X_test_feature_d.to(torch::kInt32)); + } else { + X_train_d.index_put_({ feature, "..." }, X_train.index({ feature, "..." }).to(torch::kInt32)); + X_test_d.index_put_({ feature, "..." }, X_test.index({ feature, "..." }).to(torch::kInt32)); + } + } + X_train = X_train_d; + X_test = X_test_d; + } + return { X_train, X_test, y_train, y_test }; } } \ No newline at end of file diff --git a/src/common/Dataset.h b/src/common/Dataset.h index afd609e..3bf6fc0 100644 --- a/src/common/Dataset.h +++ b/src/common/Dataset.h @@ -4,27 +4,30 @@ #include #include #include +#include #include #include "Utils.h" #include "SourceData.h" namespace platform { class Dataset { public: - Dataset(const std::string& path, const std::string& name, const std::string& className, bool discretize, fileType_t fileType, std::vector numericFeaturesIdx) : + Dataset(const std::string& path, const std::string& name, const std::string& className, bool discretize, fileType_t fileType, std::vector numericFeaturesIdx, std::string discretizer_algo = "none") : path(path), name(name), className(className), discretize(discretize), - loaded(false), fileType(fileType), numericFeaturesIdx(numericFeaturesIdx) + loaded(false), fileType(fileType), numericFeaturesIdx(numericFeaturesIdx), discretizer_algorithm(discretizer_algo) { }; explicit Dataset(const Dataset&); std::string getName() const; std::string getClassName() const; - std::vector getLabels() const { return labels; } + int getNClasses() const; + std::vector getLabels() const; // return the labels factorization result + std::vector getClassesCounts() const; std::vector getFeatures() const; std::map> getStates() const; std::pair>&, std::vector&> getVectors(); std::pair>&, std::vector&> getVectorsDiscretized(); - std::pair getDiscretizedTrainTestTensors(); std::pair getTensors(); + std::tuple getTrainTestTensors(std::vector& train, std::vector& test); int getNFeatures() const; int getNSamples() const; std::vector& getNumericFeatures() { return numericFeatures; } @@ -37,6 +40,7 @@ namespace platform { std::string className; int n_samples{ 0 }, n_features{ 0 }; std::vector numericFeaturesIdx; + std::string discretizer_algorithm; std::vector numericFeatures; // true if feature is numeric std::vector features; std::vector labels; @@ -44,11 +48,10 @@ namespace platform { bool loaded; bool discretize; torch::Tensor X, y; - torch::Tensor X_train, X_test; + torch::Tensor X_train, X_test, y_train, y_test; std::vector> Xv; std::vector> Xd; std::vector yv; - void buildTensors(); void load_csv(); void load_arff(); void load_rdata(); diff --git a/src/common/Datasets.cpp b/src/common/Datasets.cpp index e21d741..45a73e6 100644 --- a/src/common/Datasets.cpp +++ b/src/common/Datasets.cpp @@ -54,7 +54,7 @@ namespace platform { throw std::invalid_argument("Invalid catalog file format."); } - datasets[name] = make_unique(path, name, className, discretize, fileType, numericFeaturesIdx); + datasets[name] = make_unique(path, name, className, discretize, fileType, numericFeaturesIdx, discretizer_algorithm); } catalog.close(); } @@ -64,110 +64,6 @@ namespace platform { transform(datasets.begin(), datasets.end(), back_inserter(result), [](const auto& d) { return d.first; }); return result; } - std::vector Datasets::getFeatures(const std::string& name) const - { - if (datasets.at(name)->isLoaded()) { - return datasets.at(name)->getFeatures(); - } else { - throw std::invalid_argument(message_dataset_not_loaded); - } - } - std::vector Datasets::getLabels(const std::string& name) const - { - if (datasets.at(name)->isLoaded()) { - return datasets.at(name)->getLabels(); - } else { - throw std::invalid_argument(message_dataset_not_loaded); - } - } - map> Datasets::getStates(const std::string& name) const - { - if (datasets.at(name)->isLoaded()) { - return datasets.at(name)->getStates(); - } else { - throw std::invalid_argument(message_dataset_not_loaded); - } - } - void Datasets::loadDataset(const std::string& name) const - { - if (datasets.at(name)->isLoaded()) { - return; - } else { - datasets.at(name)->load(); - } - } - std::string Datasets::getClassName(const std::string& name) const - { - if (datasets.at(name)->isLoaded()) { - return datasets.at(name)->getClassName(); - } else { - throw std::invalid_argument(message_dataset_not_loaded); - } - } - int Datasets::getNSamples(const std::string& name) const - { - if (datasets.at(name)->isLoaded()) { - return datasets.at(name)->getNSamples(); - } else { - throw std::invalid_argument(message_dataset_not_loaded); - } - } - int Datasets::getNClasses(const std::string& name) - { - if (datasets.at(name)->isLoaded()) { - auto className = datasets.at(name)->getClassName(); - if (discretize) { - auto states = getStates(name); - return states.at(className).size(); - } - auto [Xv, yv] = getVectors(name); - return *std::max_element(yv.begin(), yv.end()) + 1; - } else { - throw std::invalid_argument(message_dataset_not_loaded); - } - } - std::vector& Datasets::getNumericFeatures(const std::string& name) const - { - if (datasets.at(name)->isLoaded()) { - return datasets.at(name)->getNumericFeatures(); - } else { - throw std::invalid_argument(message_dataset_not_loaded); - } - } - std::vector Datasets::getClassesCounts(const std::string& name) const - { - if (datasets.at(name)->isLoaded()) { - auto [Xv, yv] = datasets.at(name)->getVectors(); - std::vector counts(*std::max_element(yv.begin(), yv.end()) + 1); - for (auto y : yv) { - counts[y]++; - } - return counts; - } else { - throw std::invalid_argument(message_dataset_not_loaded); - } - } - pair>&, std::vector&> Datasets::getVectors(const std::string& name) - { - if (!datasets[name]->isLoaded()) { - datasets[name]->load(); - } - return datasets[name]->getVectors(); - } - pair>&, std::vector&> Datasets::getVectorsDiscretized(const std::string& name) - { - if (!datasets[name]->isLoaded()) { - datasets[name]->load(); - } - return datasets[name]->getVectorsDiscretized(); - } - pair Datasets::getTensors(const std::string& name) - { - if (!datasets[name]->isLoaded()) { - datasets[name]->load(); - } - return datasets[name]->getTensors(); - } bool Datasets::isDataset(const std::string& name) const { return datasets.find(name) != datasets.end(); diff --git a/src/common/Datasets.h b/src/common/Datasets.h index 44c8a05..35c58e7 100644 --- a/src/common/Datasets.h +++ b/src/common/Datasets.h @@ -4,34 +4,23 @@ namespace platform { class Datasets { public: - explicit Datasets(bool discretize, std::string sfileType, std::string discretizer_algo = "none") : discretize(discretize), sfileType(sfileType), discretizer_algo(discretizer_algo) + explicit Datasets(bool discretize, std::string sfileType, std::string discretizer_algorithm = "none") : + discretize(discretize), sfileType(sfileType), discretizer_algorithm(discretizer_algorithm) { - if (discretizer_algo == "none" && discretize) { + if (discretizer_algorithm == "none" && discretize) { throw std::runtime_error("Can't discretize without discretization algorithm"); } load(); }; std::vector getNames(); - std::vector getFeatures(const std::string& name) const; - int getNSamples(const std::string& name) const; - std::vector getLabels(const std::string& name) const; - std::string getClassName(const std::string& name) const; - int getNClasses(const std::string& name); - std::vector& getNumericFeatures(const std::string& name) const; - std::vector getClassesCounts(const std::string& name) const; - std::map> getStates(const std::string& name) const; - std::pair>&, std::vector&> getVectors(const std::string& name); - std::pair>&, std::vector&> getVectorsDiscretized(const std::string& name); - std::pair getTensors(const std::string& name); - std::tuple getTrainTestTensors(const std::vector& train_idx, const std::vector& test_idx); bool isDataset(const std::string& name) const; - void loadDataset(const std::string& name) const; + Dataset& getDataset(const std::string& name) const { return *datasets.at(name); } std::string toString() const; private: std::string path; fileType_t fileType; std::string sfileType; - std::string discretizer_algo; + std::string discretizer_algorithm; std::map> datasets; bool discretize; void load(); // Loads the list of datasets diff --git a/src/grid/GridSearch.cpp b/src/grid/GridSearch.cpp index 5ea92a8..556127f 100644 --- a/src/grid/GridSearch.cpp +++ b/src/grid/GridSearch.cpp @@ -118,17 +118,18 @@ namespace platform { json task = tasks[n_task]; auto model = config.model; auto grid = GridData(Paths::grid_input(model)); - auto dataset = task["dataset"].get(); + auto dataset_name = task["dataset"].get(); auto idx_dataset = task["idx_dataset"].get(); auto seed = task["seed"].get(); auto n_fold = task["fold"].get(); bool stratified = config.stratified; // Generate the hyperparamters combinations - auto combinations = grid.getGrid(dataset); - auto [X, y] = datasets.getTensors(dataset); - auto states = datasets.getStates(dataset); - auto features = datasets.getFeatures(dataset); - auto className = datasets.getClassName(dataset); + auto& dataset = datasets.getDataset(dataset_name); + auto combinations = grid.getGrid(dataset_name); + auto [X, y] = dataset.getTensors(); + auto states = dataset.getStates(); + auto features = dataset.getFeatures(); + auto className = dataset.getClassName(); // // Start working on task // @@ -138,12 +139,7 @@ namespace platform { else fold = new folding::KFold(config.n_folds, y.size(0), seed); auto [train, test] = fold->getFold(n_fold); - auto train_t = torch::tensor(train); - auto test_t = torch::tensor(test); - auto X_train = X.index({ "...", train_t }); - auto y_train = y.index({ train_t }); - auto X_test = X.index({ "...", test_t }); - auto y_test = y.index({ test_t }); + auto [X_train, X_test, y_train, y_test] = dataset.getTrainTestTensors(train, test); double best_fold_score = 0.0; int best_idx_combination = -1; json best_fold_hyper; @@ -168,8 +164,8 @@ namespace platform { // Build Classifier with selected hyperparameters auto clf = Models::instance()->create(config.model); auto valid = clf->getValidHyperparameters(); - hyperparameters.check(valid, dataset); - clf->setHyperparameters(hyperparameters.get(dataset)); + hyperparameters.check(valid, dataset_name); + clf->setHyperparameters(hyperparameters.get(dataset_name)); // Train model clf->fit(X_nested_train, y_nested_train, features, className, states); // Test model @@ -188,7 +184,7 @@ namespace platform { auto hyperparameters = platform::HyperParameters(datasets.getNames(), best_fold_hyper); auto clf = Models::instance()->create(config.model); auto valid = clf->getValidHyperparameters(); - hyperparameters.check(valid, dataset); + hyperparameters.check(valid, dataset_name); clf->setHyperparameters(best_fold_hyper); clf->fit(X_train, y_train, features, className, states); best_fold_score = clf->score(X_test, y_test); diff --git a/src/main/Experiment.cpp b/src/main/Experiment.cpp index c5fb739..c09eeae 100644 --- a/src/main/Experiment.cpp +++ b/src/main/Experiment.cpp @@ -115,23 +115,31 @@ namespace platform { } void Experiment::cross_validation(const std::string& fileName, bool quiet, bool no_train_score, bool generate_fold_files) { + // + // Load dataset and prepare data + // auto datasets = Datasets(false, Paths::datasets()); // Never discretize here - // Get dataset - // -------------- auto [X, y] = datasets.getTensors(fileName); - // -------------- auto states = datasets.getStates(fileName); - auto features = datasets.getFeatures(fileName); - auto samples = datasets.getNSamples(fileName); - auto className = datasets.getClassName(fileName); - auto labels = datasets.getLabels(fileName); - int num_classes = labels.size(); + auto& dataset = datasets.getDataset(fileName); + dataset.load(); + auto [X, y] = dataset.getTensors(); // Only need y for folding + auto features = dataset.getFeatures(); + auto n_features = dataset.getNFeatures(); + auto n_samples = dataset.getNSamples(); + auto className = dataset.getClassName(); + auto labels = dataset.getLabels(); + int num_classes = dataset.getNClasses(); if (!quiet) { - std::cout << " " << setw(5) << samples << " " << setw(5) << features.size() << flush; + std::cout << " " << setw(5) << n_samples << " " << setw(5) << n_features << flush; } + // // Prepare Result + // auto partial_result = PartialResult(); - partial_result.setSamples(samples).setFeatures(features.size()).setClasses(num_classes); + partial_result.setSamples(n_samples).setFeatures(n_features).setClasses(num_classes); partial_result.setHyperparameters(hyperparameters.get(fileName)); + // // Initialize results std::vectors + // int nResults = nfolds * static_cast(randomSeeds.size()); auto accuracy_test = torch::zeros({ nResults }, torch::kFloat64); auto accuracy_train = torch::zeros({ nResults }, torch::kFloat64); @@ -146,6 +154,9 @@ namespace platform { Timer train_timer, test_timer; int item = 0; bool first_seed = true; + // + // Loop over random seeds + // for (auto seed : randomSeeds) { if (!quiet) { string prefix = " "; @@ -159,25 +170,30 @@ namespace platform { if (stratified) fold = new folding::StratifiedKFold(nfolds, y, seed); else - fold = new folding::KFold(nfolds, y.size(0), seed); + fold = new folding::KFold(nfolds, n_samples, seed); + // + // Loop over folds + // for (int nfold = 0; nfold < nfolds; nfold++) { auto clf = Models::instance()->create(result.getModel()); setModelVersion(clf->getVersion()); auto valid = clf->getValidHyperparameters(); hyperparameters.check(valid, fileName); clf->setHyperparameters(hyperparameters.get(fileName)); + // // Split train - test dataset + // train_timer.start(); auto [train, test] = fold->getFold(nfold); - auto [X_train, X_test, y_train, y_test] = datasets.getTrainTestTensors(fileName, train, test); - // Posibilidad de quitar todos los métodos de datasets y dejar un sólo de getDataset que devuelva - // una referencia al objeto dataset y trabajar directamente con él. - auto states = datasets.getStates(fileName); + auto [X_train, X_test, y_train, y_test] = dataset.getTrainTestTensors(train, test); + auto states = dataset.getStates(); if (generate_fold_files) generate_files(fileName, discretized, stratified, seed, nfold, X_train, y_train, X_test, y_test, train, test); if (!quiet) showProgress(nfold + 1, getColor(clf->getStatus()), "a"); + // // Train model + // clf->fit(X_train, y_train, features, className, states); if (!quiet) showProgress(nfold + 1, getColor(clf->getStatus()), "b"); @@ -189,14 +205,18 @@ namespace platform { num_states[item] = clf->getNumberOfStates(); train_time[item] = train_timer.getDuration(); double accuracy_train_value = 0.0; + // // Score train + // if (!no_train_score) { auto y_predict = clf->predict(X_train); Scores scores(y_train, y_predict, num_classes, labels); accuracy_train_value = scores.accuracy(); confusion_matrices_train.push_back(scores.get_confusion_matrix_json(true)); } + // // Test model + // if (!quiet) showProgress(nfold + 1, getColor(clf->getStatus()), "c"); test_timer.start(); @@ -209,7 +229,9 @@ namespace platform { confusion_matrices.push_back(scores.get_confusion_matrix_json(true)); if (!quiet) std::cout << "\b\b\b, " << flush; + // // Store results and times in std::vector + // partial_result.addScoreTrain(accuracy_train_value); partial_result.addScoreTest(accuracy_test_value); partial_result.addTimeTrain(train_time[item].item()); @@ -220,6 +242,9 @@ namespace platform { std::cout << "end. " << flush; delete fold; } + // + // Store result totals in Result + // partial_result.setScoreTest(torch::mean(accuracy_test).item()).setScoreTrain(torch::mean(accuracy_train).item()); partial_result.setScoreTestStd(torch::std(accuracy_test).item()).setScoreTrainStd(torch::std(accuracy_train).item()); partial_result.setTrainTime(torch::mean(train_time).item()).setTestTime(torch::mean(test_time).item()); diff --git a/src/reports/DatasetsConsole.cpp b/src/reports/DatasetsConsole.cpp index 9c0ddcc..59bd454 100644 --- a/src/reports/DatasetsConsole.cpp +++ b/src/reports/DatasetsConsole.cpp @@ -42,35 +42,37 @@ namespace platform { sline += "\n"; header.push_back(sline); int num = 0; - for (const auto& dataset : datasets.getNames()) { + for (const auto& dataset_name : datasets.getNames()) { std::stringstream line; line.imbue(loc); auto color = num % 2 ? Colors::CYAN() : Colors::BLUE(); line << color << setw(3) << right << num++ << " "; - line << setw(maxName) << left << dataset << " "; - datasets.loadDataset(dataset); - auto nSamples = datasets.getNSamples(dataset); + line << setw(maxName) << left << dataset_name << " "; + auto& dataset = datasets.getDataset(dataset_name); + dataset.load(); + auto nSamples = dataset.getNSamples(); line << setw(6) << right << nSamples << " "; - auto nFeatures = datasets.getFeatures(dataset).size(); + auto nFeatures = dataset.getFeatures().size(); line << setw(5) << right << nFeatures << " "; - auto numericFeatures = datasets.getNumericFeatures(dataset); + auto numericFeatures = dataset.getNumericFeatures(); auto num = std::count(numericFeatures.begin(), numericFeatures.end(), true); line << setw(5) << right << num << " "; - line << setw(3) << right << datasets.getNClasses(dataset) << " "; + auto nClasses = dataset.getNClasses(); + line << setw(3) << right << nClasses << " "; std::string sep = ""; oss.str(""); - for (auto number : datasets.getClassesCounts(dataset)) { + for (auto number : dataset.getClassesCounts()) { oss << sep << std::setprecision(2) << fixed << (float)number / nSamples * 100.0 << "% (" << number << ")"; sep = " / "; } split_lines(maxName, line.str(), oss.str()); // Store data for Excel report - data[dataset] = json::object(); - data[dataset]["samples"] = nSamples; - data[dataset]["features"] = datasets.getFeatures(dataset).size(); - data[dataset]["numericFeatures"] = num; - data[dataset]["classes"] = datasets.getNClasses(dataset); - data[dataset]["balance"] = oss.str(); + data[dataset_name] = json::object(); + data[dataset_name]["samples"] = nSamples; + data[dataset_name]["features"] = nFeatures; + data[dataset_name]["numericFeatures"] = num; + data[dataset_name]["classes"] = nClasses; + data[dataset_name]["balance"] = oss.str(); } } } diff --git a/src/reports/ReportBase.cpp b/src/reports/ReportBase.cpp index ef02e72..b8eceeb 100644 --- a/src/reports/ReportBase.cpp +++ b/src/reports/ReportBase.cpp @@ -61,12 +61,13 @@ namespace platform { } } else { if (data["score_name"].get() == "accuracy") { - auto dt = Datasets(false, Paths::datasets()); - dt.loadDataset(dataset); - auto numClasses = dt.getNClasses(dataset); + auto datasets = Datasets(false, Paths::datasets()); + auto& dt = datasets.getDataset(dataset); + dt.load(); + auto numClasses = dt.getNClasses(); if (numClasses == 2) { - std::vector distribution = dt.getClassesCounts(dataset); - double nSamples = dt.getNSamples(dataset); + std::vector distribution = dt.getClassesCounts(); + double nSamples = dt.getNSamples(); std::vector::iterator maxValue = max_element(distribution.begin(), distribution.end()); double mark = *maxValue / nSamples * (1 + margin); if (mark > 1) {