Add traintest split in gridsearch

This commit is contained in:
2024-06-07 11:05:59 +02:00
parent 5dd3deca1a
commit 361c51d864
8 changed files with 213 additions and 247 deletions

View File

@@ -15,10 +15,6 @@ namespace platform {
{ {
return name; return name;
} }
std::string Dataset::getClassName() const
{
return className;
}
std::vector<std::string> Dataset::getFeatures() const std::vector<std::string> Dataset::getFeatures() const
{ {
if (loaded) { if (loaded) {
@@ -43,6 +39,42 @@ namespace platform {
throw std::invalid_argument(message_dataset_not_loaded); throw std::invalid_argument(message_dataset_not_loaded);
} }
} }
std::string Dataset::getClassName() const
{
return className;
}
int Dataset::getNClasses() const
{
if (loaded) {
if (discretize) {
return states.at(className).size();
}
return *std::max_element(yv.begin(), yv.end()) + 1;
} else {
throw std::invalid_argument(message_dataset_not_loaded);
}
}
std::vector<std::string> Dataset::getLabels() const
{
// Return the labels factorization result
if (loaded) {
return labels;
} else {
throw std::invalid_argument(message_dataset_not_loaded);
}
}
std::vector<int> Dataset::getClassesCounts() const
{
if (loaded) {
std::vector<int> counts(*std::max_element(yv.begin(), yv.end()) + 1);
for (auto y : yv) {
counts[y]++;
}
return counts;
} else {
throw std::invalid_argument(message_dataset_not_loaded);
}
}
std::map<std::string, std::vector<int>> Dataset::getStates() const std::map<std::string, std::vector<int>> Dataset::getStates() const
{ {
if (loaded) { if (loaded) {
@@ -70,7 +102,6 @@ namespace platform {
pair<torch::Tensor&, torch::Tensor&> Dataset::getTensors() pair<torch::Tensor&, torch::Tensor&> Dataset::getTensors()
{ {
if (loaded) { if (loaded) {
buildTensors();
return { X, y }; return { X, y };
} else { } else {
throw std::invalid_argument(message_dataset_not_loaded); throw std::invalid_argument(message_dataset_not_loaded);
@@ -79,29 +110,32 @@ namespace platform {
void Dataset::load_csv() void Dataset::load_csv()
{ {
ifstream file(path + "/" + name + ".csv"); ifstream file(path + "/" + name + ".csv");
if (file.is_open()) { if (!file.is_open()) {
std::string line;
getline(file, line);
std::vector<std::string> tokens = split(line, ',');
features = std::vector<std::string>(tokens.begin(), tokens.end() - 1);
if (className == "-1") {
className = tokens.back();
}
for (auto i = 0; i < features.size(); ++i) {
Xv.push_back(std::vector<float>());
}
while (getline(file, line)) {
tokens = split(line, ',');
for (auto i = 0; i < features.size(); ++i) {
Xv[i].push_back(stof(tokens[i]));
}
yv.push_back(stoi(tokens.back()));
}
labels.clear();
file.close();
} else {
throw std::invalid_argument("Unable to open dataset file."); throw std::invalid_argument("Unable to open dataset file.");
} }
labels.clear();
std::string line;
getline(file, line);
std::vector<std::string> tokens = split(line, ',');
features = std::vector<std::string>(tokens.begin(), tokens.end() - 1);
if (className == "-1") {
className = tokens.back();
}
for (auto i = 0; i < features.size(); ++i) {
Xv.push_back(std::vector<float>());
}
while (getline(file, line)) {
tokens = split(line, ',');
for (auto i = 0; i < features.size(); ++i) {
Xv[i].push_back(stof(tokens[i]));
}
auto label = trim(tokens.back());
if (find(labels.begin(), labels.end(), label) == labels.end()) {
labels.push_back(label);
}
yv.push_back(stoi(label));
}
file.close();
} }
void Dataset::computeStates() void Dataset::computeStates()
{ {
@@ -147,32 +181,35 @@ namespace platform {
void Dataset::load_rdata() void Dataset::load_rdata()
{ {
ifstream file(path + "/" + name + "_R.dat"); ifstream file(path + "/" + name + "_R.dat");
if (file.is_open()) { if (!file.is_open()) {
std::string line;
getline(file, line);
line = ArffFiles::trim(line);
std::vector<std::string> tokens = tokenize(line);
transform(tokens.begin(), tokens.end() - 1, back_inserter(features), [](const auto& attribute) { return ArffFiles::trim(attribute); });
if (className == "-1") {
className = ArffFiles::trim(tokens.back());
}
for (auto i = 0; i < features.size(); ++i) {
Xv.push_back(std::vector<float>());
}
while (getline(file, line)) {
tokens = tokenize(line);
// We have to skip the first token, which is the instance number.
for (auto i = 1; i < features.size() + 1; ++i) {
const float value = stof(tokens[i]);
Xv[i - 1].push_back(value);
}
yv.push_back(stoi(tokens.back()));
}
labels.clear();
file.close();
} else {
throw std::invalid_argument("Unable to open dataset file."); throw std::invalid_argument("Unable to open dataset file.");
} }
std::string line;
labels.clear();
getline(file, line);
line = ArffFiles::trim(line);
std::vector<std::string> tokens = tokenize(line);
transform(tokens.begin(), tokens.end() - 1, back_inserter(features), [](const auto& attribute) { return ArffFiles::trim(attribute); });
if (className == "-1") {
className = ArffFiles::trim(tokens.back());
}
for (auto i = 0; i < features.size(); ++i) {
Xv.push_back(std::vector<float>());
}
while (getline(file, line)) {
tokens = tokenize(line);
// We have to skip the first token, which is the instance number.
for (auto i = 1; i < features.size() + 1; ++i) {
const float value = stof(tokens[i]);
Xv[i - 1].push_back(value);
}
auto label = trim(tokens.back());
if (find(labels.begin(), labels.end(), label) == labels.end()) {
labels.push_back(label);
}
yv.push_back(stoi(label));
}
file.close();
} }
void Dataset::load() void Dataset::load()
{ {
@@ -200,27 +237,13 @@ namespace platform {
} }
} }
} }
if (discretize) { // Build Tensors
Xd = discretizeDataset(Xv, yv); X = torch::zeros({ n_features, n_samples }, torch::kFloat32);
computeStates();
}
loaded = true;
}
void Dataset::buildTensors()
{
if (discretize) {
X = torch::zeros({ static_cast<int>(n_features), static_cast<int>(n_samples) }, torch::kInt32);
} else {
X = torch::zeros({ static_cast<int>(n_features), static_cast<int>(n_samples) }, torch::kFloat32);
}
for (int i = 0; i < features.size(); ++i) { for (int i = 0; i < features.size(); ++i) {
if (discretize) { X.index_put_({ i, "..." }, torch::tensor(Xv[i], torch::kFloat32));
X.index_put_({ i, "..." }, torch::tensor(Xd[i], torch::kInt32));
} else {
X.index_put_({ i, "..." }, torch::tensor(Xv[i], torch::kFloat32));
}
} }
y = torch::tensor(yv, torch::kInt32); y = torch::tensor(yv, torch::kInt32);
loaded = true;
} }
std::vector<mdlp::labels_t> Dataset::discretizeDataset(std::vector<mdlp::samples_t>& X, mdlp::labels_t& y) std::vector<mdlp::labels_t> Dataset::discretizeDataset(std::vector<mdlp::samples_t>& X, mdlp::labels_t& y)
{ {
@@ -233,9 +256,40 @@ namespace platform {
} }
return Xd; return Xd;
} }
std::pair <torch::Tensor&, torch::Tensor&> Dataset::getDiscretizedTrainTestTensors() std::tuple<torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&> Dataset::getTrainTestTensors(std::vector<int>& train, std::vector<int>& test)
{ {
auto discretizer = Discretization::instance()->create("mdlp"); if (!loaded) {
return { X_train, X_test }; throw std::invalid_argument(message_dataset_not_loaded);
}
auto train_t = torch::tensor(train);
int samples_train = train.size();
int samples_test = test.size();
auto test_t = torch::tensor(test);
X_train = X.index({ "...", train_t });
y_train = y.index({ train_t });
X_test = X.index({ "...", test_t });
y_test = y.index({ test_t });
if (discretize) {
auto discretizer = Discretization::instance()->create(discretizer_algorithm);
auto X_train_d = torch::zeros({ n_features, samples_train }, torch::kInt32);
auto X_test_d = torch::zeros({ n_features, samples_test }, torch::kInt32);
for (int feature = 0; feature < n_features; ++feature) {
if (numericFeatures[feature]) {
auto X_train_feature = X_train.index({ feature, "..." }).to(torch::kFloat32);
auto X_test_feature = X_test.index({ feature, "..." }).to(torch::kFloat32);
discretizer->fit(X_train_feature, y_train);
auto X_train_feature_d = discretizer->transform(X_train_feature);
auto X_test_feature_d = discretizer->transform(X_test_feature);
X_train_d.index_put_({ feature, "..." }, X_train_feature_d.to(torch::kInt32));
X_test_d.index_put_({ feature, "..." }, X_test_feature_d.to(torch::kInt32));
} else {
X_train_d.index_put_({ feature, "..." }, X_train.index({ feature, "..." }).to(torch::kInt32));
X_test_d.index_put_({ feature, "..." }, X_test.index({ feature, "..." }).to(torch::kInt32));
}
}
X_train = X_train_d;
X_test = X_test_d;
}
return { X_train, X_test, y_train, y_test };
} }
} }

View File

@@ -4,27 +4,30 @@
#include <map> #include <map>
#include <vector> #include <vector>
#include <string> #include <string>
#include <tuple>
#include <common/DiscretizationRegister.h> #include <common/DiscretizationRegister.h>
#include "Utils.h" #include "Utils.h"
#include "SourceData.h" #include "SourceData.h"
namespace platform { namespace platform {
class Dataset { class Dataset {
public: public:
Dataset(const std::string& path, const std::string& name, const std::string& className, bool discretize, fileType_t fileType, std::vector<int> numericFeaturesIdx) : Dataset(const std::string& path, const std::string& name, const std::string& className, bool discretize, fileType_t fileType, std::vector<int> numericFeaturesIdx, std::string discretizer_algo = "none") :
path(path), name(name), className(className), discretize(discretize), path(path), name(name), className(className), discretize(discretize),
loaded(false), fileType(fileType), numericFeaturesIdx(numericFeaturesIdx) loaded(false), fileType(fileType), numericFeaturesIdx(numericFeaturesIdx), discretizer_algorithm(discretizer_algo)
{ {
}; };
explicit Dataset(const Dataset&); explicit Dataset(const Dataset&);
std::string getName() const; std::string getName() const;
std::string getClassName() const; std::string getClassName() const;
std::vector<std::string> getLabels() const { return labels; } int getNClasses() const;
std::vector<std::string> getLabels() const; // return the labels factorization result
std::vector<int> getClassesCounts() const;
std::vector<string> getFeatures() const; std::vector<string> getFeatures() const;
std::map<std::string, std::vector<int>> getStates() const; std::map<std::string, std::vector<int>> getStates() const;
std::pair<vector<std::vector<float>>&, std::vector<int>&> getVectors(); std::pair<vector<std::vector<float>>&, std::vector<int>&> getVectors();
std::pair<vector<std::vector<int>>&, std::vector<int>&> getVectorsDiscretized(); std::pair<vector<std::vector<int>>&, std::vector<int>&> getVectorsDiscretized();
std::pair<torch::Tensor&, torch::Tensor&> getDiscretizedTrainTestTensors();
std::pair<torch::Tensor&, torch::Tensor&> getTensors(); std::pair<torch::Tensor&, torch::Tensor&> getTensors();
std::tuple<torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&> getTrainTestTensors(std::vector<int>& train, std::vector<int>& test);
int getNFeatures() const; int getNFeatures() const;
int getNSamples() const; int getNSamples() const;
std::vector<bool>& getNumericFeatures() { return numericFeatures; } std::vector<bool>& getNumericFeatures() { return numericFeatures; }
@@ -37,6 +40,7 @@ namespace platform {
std::string className; std::string className;
int n_samples{ 0 }, n_features{ 0 }; int n_samples{ 0 }, n_features{ 0 };
std::vector<int> numericFeaturesIdx; std::vector<int> numericFeaturesIdx;
std::string discretizer_algorithm;
std::vector<bool> numericFeatures; // true if feature is numeric std::vector<bool> numericFeatures; // true if feature is numeric
std::vector<std::string> features; std::vector<std::string> features;
std::vector<std::string> labels; std::vector<std::string> labels;
@@ -44,11 +48,10 @@ namespace platform {
bool loaded; bool loaded;
bool discretize; bool discretize;
torch::Tensor X, y; torch::Tensor X, y;
torch::Tensor X_train, X_test; torch::Tensor X_train, X_test, y_train, y_test;
std::vector<std::vector<float>> Xv; std::vector<std::vector<float>> Xv;
std::vector<std::vector<int>> Xd; std::vector<std::vector<int>> Xd;
std::vector<int> yv; std::vector<int> yv;
void buildTensors();
void load_csv(); void load_csv();
void load_arff(); void load_arff();
void load_rdata(); void load_rdata();

View File

@@ -54,7 +54,7 @@ namespace platform {
throw std::invalid_argument("Invalid catalog file format."); throw std::invalid_argument("Invalid catalog file format.");
} }
datasets[name] = make_unique<Dataset>(path, name, className, discretize, fileType, numericFeaturesIdx); datasets[name] = make_unique<Dataset>(path, name, className, discretize, fileType, numericFeaturesIdx, discretizer_algorithm);
} }
catalog.close(); catalog.close();
} }
@@ -64,110 +64,6 @@ namespace platform {
transform(datasets.begin(), datasets.end(), back_inserter(result), [](const auto& d) { return d.first; }); transform(datasets.begin(), datasets.end(), back_inserter(result), [](const auto& d) { return d.first; });
return result; return result;
} }
std::vector<std::string> Datasets::getFeatures(const std::string& name) const
{
if (datasets.at(name)->isLoaded()) {
return datasets.at(name)->getFeatures();
} else {
throw std::invalid_argument(message_dataset_not_loaded);
}
}
std::vector<std::string> Datasets::getLabels(const std::string& name) const
{
if (datasets.at(name)->isLoaded()) {
return datasets.at(name)->getLabels();
} else {
throw std::invalid_argument(message_dataset_not_loaded);
}
}
map<std::string, std::vector<int>> Datasets::getStates(const std::string& name) const
{
if (datasets.at(name)->isLoaded()) {
return datasets.at(name)->getStates();
} else {
throw std::invalid_argument(message_dataset_not_loaded);
}
}
void Datasets::loadDataset(const std::string& name) const
{
if (datasets.at(name)->isLoaded()) {
return;
} else {
datasets.at(name)->load();
}
}
std::string Datasets::getClassName(const std::string& name) const
{
if (datasets.at(name)->isLoaded()) {
return datasets.at(name)->getClassName();
} else {
throw std::invalid_argument(message_dataset_not_loaded);
}
}
int Datasets::getNSamples(const std::string& name) const
{
if (datasets.at(name)->isLoaded()) {
return datasets.at(name)->getNSamples();
} else {
throw std::invalid_argument(message_dataset_not_loaded);
}
}
int Datasets::getNClasses(const std::string& name)
{
if (datasets.at(name)->isLoaded()) {
auto className = datasets.at(name)->getClassName();
if (discretize) {
auto states = getStates(name);
return states.at(className).size();
}
auto [Xv, yv] = getVectors(name);
return *std::max_element(yv.begin(), yv.end()) + 1;
} else {
throw std::invalid_argument(message_dataset_not_loaded);
}
}
std::vector<bool>& Datasets::getNumericFeatures(const std::string& name) const
{
if (datasets.at(name)->isLoaded()) {
return datasets.at(name)->getNumericFeatures();
} else {
throw std::invalid_argument(message_dataset_not_loaded);
}
}
std::vector<int> Datasets::getClassesCounts(const std::string& name) const
{
if (datasets.at(name)->isLoaded()) {
auto [Xv, yv] = datasets.at(name)->getVectors();
std::vector<int> counts(*std::max_element(yv.begin(), yv.end()) + 1);
for (auto y : yv) {
counts[y]++;
}
return counts;
} else {
throw std::invalid_argument(message_dataset_not_loaded);
}
}
pair<std::vector<std::vector<float>>&, std::vector<int>&> Datasets::getVectors(const std::string& name)
{
if (!datasets[name]->isLoaded()) {
datasets[name]->load();
}
return datasets[name]->getVectors();
}
pair<std::vector<std::vector<int>>&, std::vector<int>&> Datasets::getVectorsDiscretized(const std::string& name)
{
if (!datasets[name]->isLoaded()) {
datasets[name]->load();
}
return datasets[name]->getVectorsDiscretized();
}
pair<torch::Tensor&, torch::Tensor&> Datasets::getTensors(const std::string& name)
{
if (!datasets[name]->isLoaded()) {
datasets[name]->load();
}
return datasets[name]->getTensors();
}
bool Datasets::isDataset(const std::string& name) const bool Datasets::isDataset(const std::string& name) const
{ {
return datasets.find(name) != datasets.end(); return datasets.find(name) != datasets.end();

View File

@@ -4,34 +4,23 @@
namespace platform { namespace platform {
class Datasets { class Datasets {
public: public:
explicit Datasets(bool discretize, std::string sfileType, std::string discretizer_algo = "none") : discretize(discretize), sfileType(sfileType), discretizer_algo(discretizer_algo) explicit Datasets(bool discretize, std::string sfileType, std::string discretizer_algorithm = "none") :
discretize(discretize), sfileType(sfileType), discretizer_algorithm(discretizer_algorithm)
{ {
if (discretizer_algo == "none" && discretize) { if (discretizer_algorithm == "none" && discretize) {
throw std::runtime_error("Can't discretize without discretization algorithm"); throw std::runtime_error("Can't discretize without discretization algorithm");
} }
load(); load();
}; };
std::vector<std::string> getNames(); std::vector<std::string> getNames();
std::vector<std::string> getFeatures(const std::string& name) const;
int getNSamples(const std::string& name) const;
std::vector<std::string> getLabels(const std::string& name) const;
std::string getClassName(const std::string& name) const;
int getNClasses(const std::string& name);
std::vector<bool>& getNumericFeatures(const std::string& name) const;
std::vector<int> getClassesCounts(const std::string& name) const;
std::map<std::string, std::vector<int>> getStates(const std::string& name) const;
std::pair<std::vector<std::vector<float>>&, std::vector<int>&> getVectors(const std::string& name);
std::pair<std::vector<std::vector<int>>&, std::vector<int>&> getVectorsDiscretized(const std::string& name);
std::pair<torch::Tensor&, torch::Tensor&> getTensors(const std::string& name);
std::tuple<torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&> getTrainTestTensors(const std::vector<int>& train_idx, const std::vector<int>& test_idx);
bool isDataset(const std::string& name) const; bool isDataset(const std::string& name) const;
void loadDataset(const std::string& name) const; Dataset& getDataset(const std::string& name) const { return *datasets.at(name); }
std::string toString() const; std::string toString() const;
private: private:
std::string path; std::string path;
fileType_t fileType; fileType_t fileType;
std::string sfileType; std::string sfileType;
std::string discretizer_algo; std::string discretizer_algorithm;
std::map<std::string, std::unique_ptr<Dataset>> datasets; std::map<std::string, std::unique_ptr<Dataset>> datasets;
bool discretize; bool discretize;
void load(); // Loads the list of datasets void load(); // Loads the list of datasets

View File

@@ -118,17 +118,18 @@ namespace platform {
json task = tasks[n_task]; json task = tasks[n_task];
auto model = config.model; auto model = config.model;
auto grid = GridData(Paths::grid_input(model)); auto grid = GridData(Paths::grid_input(model));
auto dataset = task["dataset"].get<std::string>(); auto dataset_name = task["dataset"].get<std::string>();
auto idx_dataset = task["idx_dataset"].get<int>(); auto idx_dataset = task["idx_dataset"].get<int>();
auto seed = task["seed"].get<int>(); auto seed = task["seed"].get<int>();
auto n_fold = task["fold"].get<int>(); auto n_fold = task["fold"].get<int>();
bool stratified = config.stratified; bool stratified = config.stratified;
// Generate the hyperparamters combinations // Generate the hyperparamters combinations
auto combinations = grid.getGrid(dataset); auto& dataset = datasets.getDataset(dataset_name);
auto [X, y] = datasets.getTensors(dataset); auto combinations = grid.getGrid(dataset_name);
auto states = datasets.getStates(dataset); auto [X, y] = dataset.getTensors();
auto features = datasets.getFeatures(dataset); auto states = dataset.getStates();
auto className = datasets.getClassName(dataset); auto features = dataset.getFeatures();
auto className = dataset.getClassName();
// //
// Start working on task // Start working on task
// //
@@ -138,12 +139,7 @@ namespace platform {
else else
fold = new folding::KFold(config.n_folds, y.size(0), seed); fold = new folding::KFold(config.n_folds, y.size(0), seed);
auto [train, test] = fold->getFold(n_fold); auto [train, test] = fold->getFold(n_fold);
auto train_t = torch::tensor(train); auto [X_train, X_test, y_train, y_test] = dataset.getTrainTestTensors(train, test);
auto test_t = torch::tensor(test);
auto X_train = X.index({ "...", train_t });
auto y_train = y.index({ train_t });
auto X_test = X.index({ "...", test_t });
auto y_test = y.index({ test_t });
double best_fold_score = 0.0; double best_fold_score = 0.0;
int best_idx_combination = -1; int best_idx_combination = -1;
json best_fold_hyper; json best_fold_hyper;
@@ -168,8 +164,8 @@ namespace platform {
// Build Classifier with selected hyperparameters // Build Classifier with selected hyperparameters
auto clf = Models::instance()->create(config.model); auto clf = Models::instance()->create(config.model);
auto valid = clf->getValidHyperparameters(); auto valid = clf->getValidHyperparameters();
hyperparameters.check(valid, dataset); hyperparameters.check(valid, dataset_name);
clf->setHyperparameters(hyperparameters.get(dataset)); clf->setHyperparameters(hyperparameters.get(dataset_name));
// Train model // Train model
clf->fit(X_nested_train, y_nested_train, features, className, states); clf->fit(X_nested_train, y_nested_train, features, className, states);
// Test model // Test model
@@ -188,7 +184,7 @@ namespace platform {
auto hyperparameters = platform::HyperParameters(datasets.getNames(), best_fold_hyper); auto hyperparameters = platform::HyperParameters(datasets.getNames(), best_fold_hyper);
auto clf = Models::instance()->create(config.model); auto clf = Models::instance()->create(config.model);
auto valid = clf->getValidHyperparameters(); auto valid = clf->getValidHyperparameters();
hyperparameters.check(valid, dataset); hyperparameters.check(valid, dataset_name);
clf->setHyperparameters(best_fold_hyper); clf->setHyperparameters(best_fold_hyper);
clf->fit(X_train, y_train, features, className, states); clf->fit(X_train, y_train, features, className, states);
best_fold_score = clf->score(X_test, y_test); best_fold_score = clf->score(X_test, y_test);

View File

@@ -115,23 +115,31 @@ namespace platform {
} }
void Experiment::cross_validation(const std::string& fileName, bool quiet, bool no_train_score, bool generate_fold_files) void Experiment::cross_validation(const std::string& fileName, bool quiet, bool no_train_score, bool generate_fold_files)
{ {
//
// Load dataset and prepare data
//
auto datasets = Datasets(false, Paths::datasets()); // Never discretize here auto datasets = Datasets(false, Paths::datasets()); // Never discretize here
// Get dataset auto& dataset = datasets.getDataset(fileName);
// -------------- auto [X, y] = datasets.getTensors(fileName); dataset.load();
// -------------- auto states = datasets.getStates(fileName); auto [X, y] = dataset.getTensors(); // Only need y for folding
auto features = datasets.getFeatures(fileName); auto features = dataset.getFeatures();
auto samples = datasets.getNSamples(fileName); auto n_features = dataset.getNFeatures();
auto className = datasets.getClassName(fileName); auto n_samples = dataset.getNSamples();
auto labels = datasets.getLabels(fileName); auto className = dataset.getClassName();
int num_classes = labels.size(); auto labels = dataset.getLabels();
int num_classes = dataset.getNClasses();
if (!quiet) { if (!quiet) {
std::cout << " " << setw(5) << samples << " " << setw(5) << features.size() << flush; std::cout << " " << setw(5) << n_samples << " " << setw(5) << n_features << flush;
} }
//
// Prepare Result // Prepare Result
//
auto partial_result = PartialResult(); auto partial_result = PartialResult();
partial_result.setSamples(samples).setFeatures(features.size()).setClasses(num_classes); partial_result.setSamples(n_samples).setFeatures(n_features).setClasses(num_classes);
partial_result.setHyperparameters(hyperparameters.get(fileName)); partial_result.setHyperparameters(hyperparameters.get(fileName));
//
// Initialize results std::vectors // Initialize results std::vectors
//
int nResults = nfolds * static_cast<int>(randomSeeds.size()); int nResults = nfolds * static_cast<int>(randomSeeds.size());
auto accuracy_test = torch::zeros({ nResults }, torch::kFloat64); auto accuracy_test = torch::zeros({ nResults }, torch::kFloat64);
auto accuracy_train = torch::zeros({ nResults }, torch::kFloat64); auto accuracy_train = torch::zeros({ nResults }, torch::kFloat64);
@@ -146,6 +154,9 @@ namespace platform {
Timer train_timer, test_timer; Timer train_timer, test_timer;
int item = 0; int item = 0;
bool first_seed = true; bool first_seed = true;
//
// Loop over random seeds
//
for (auto seed : randomSeeds) { for (auto seed : randomSeeds) {
if (!quiet) { if (!quiet) {
string prefix = " "; string prefix = " ";
@@ -159,25 +170,30 @@ namespace platform {
if (stratified) if (stratified)
fold = new folding::StratifiedKFold(nfolds, y, seed); fold = new folding::StratifiedKFold(nfolds, y, seed);
else else
fold = new folding::KFold(nfolds, y.size(0), seed); fold = new folding::KFold(nfolds, n_samples, seed);
//
// Loop over folds
//
for (int nfold = 0; nfold < nfolds; nfold++) { for (int nfold = 0; nfold < nfolds; nfold++) {
auto clf = Models::instance()->create(result.getModel()); auto clf = Models::instance()->create(result.getModel());
setModelVersion(clf->getVersion()); setModelVersion(clf->getVersion());
auto valid = clf->getValidHyperparameters(); auto valid = clf->getValidHyperparameters();
hyperparameters.check(valid, fileName); hyperparameters.check(valid, fileName);
clf->setHyperparameters(hyperparameters.get(fileName)); clf->setHyperparameters(hyperparameters.get(fileName));
//
// Split train - test dataset // Split train - test dataset
//
train_timer.start(); train_timer.start();
auto [train, test] = fold->getFold(nfold); auto [train, test] = fold->getFold(nfold);
auto [X_train, X_test, y_train, y_test] = datasets.getTrainTestTensors(fileName, train, test); auto [X_train, X_test, y_train, y_test] = dataset.getTrainTestTensors(train, test);
// Posibilidad de quitar todos los métodos de datasets y dejar un sólo de getDataset que devuelva auto states = dataset.getStates();
// una referencia al objeto dataset y trabajar directamente con él.
auto states = datasets.getStates(fileName);
if (generate_fold_files) if (generate_fold_files)
generate_files(fileName, discretized, stratified, seed, nfold, X_train, y_train, X_test, y_test, train, test); generate_files(fileName, discretized, stratified, seed, nfold, X_train, y_train, X_test, y_test, train, test);
if (!quiet) if (!quiet)
showProgress(nfold + 1, getColor(clf->getStatus()), "a"); showProgress(nfold + 1, getColor(clf->getStatus()), "a");
//
// Train model // Train model
//
clf->fit(X_train, y_train, features, className, states); clf->fit(X_train, y_train, features, className, states);
if (!quiet) if (!quiet)
showProgress(nfold + 1, getColor(clf->getStatus()), "b"); showProgress(nfold + 1, getColor(clf->getStatus()), "b");
@@ -189,14 +205,18 @@ namespace platform {
num_states[item] = clf->getNumberOfStates(); num_states[item] = clf->getNumberOfStates();
train_time[item] = train_timer.getDuration(); train_time[item] = train_timer.getDuration();
double accuracy_train_value = 0.0; double accuracy_train_value = 0.0;
//
// Score train // Score train
//
if (!no_train_score) { if (!no_train_score) {
auto y_predict = clf->predict(X_train); auto y_predict = clf->predict(X_train);
Scores scores(y_train, y_predict, num_classes, labels); Scores scores(y_train, y_predict, num_classes, labels);
accuracy_train_value = scores.accuracy(); accuracy_train_value = scores.accuracy();
confusion_matrices_train.push_back(scores.get_confusion_matrix_json(true)); confusion_matrices_train.push_back(scores.get_confusion_matrix_json(true));
} }
//
// Test model // Test model
//
if (!quiet) if (!quiet)
showProgress(nfold + 1, getColor(clf->getStatus()), "c"); showProgress(nfold + 1, getColor(clf->getStatus()), "c");
test_timer.start(); test_timer.start();
@@ -209,7 +229,9 @@ namespace platform {
confusion_matrices.push_back(scores.get_confusion_matrix_json(true)); confusion_matrices.push_back(scores.get_confusion_matrix_json(true));
if (!quiet) if (!quiet)
std::cout << "\b\b\b, " << flush; std::cout << "\b\b\b, " << flush;
//
// Store results and times in std::vector // Store results and times in std::vector
//
partial_result.addScoreTrain(accuracy_train_value); partial_result.addScoreTrain(accuracy_train_value);
partial_result.addScoreTest(accuracy_test_value); partial_result.addScoreTest(accuracy_test_value);
partial_result.addTimeTrain(train_time[item].item<double>()); partial_result.addTimeTrain(train_time[item].item<double>());
@@ -220,6 +242,9 @@ namespace platform {
std::cout << "end. " << flush; std::cout << "end. " << flush;
delete fold; delete fold;
} }
//
// Store result totals in Result
//
partial_result.setScoreTest(torch::mean(accuracy_test).item<double>()).setScoreTrain(torch::mean(accuracy_train).item<double>()); partial_result.setScoreTest(torch::mean(accuracy_test).item<double>()).setScoreTrain(torch::mean(accuracy_train).item<double>());
partial_result.setScoreTestStd(torch::std(accuracy_test).item<double>()).setScoreTrainStd(torch::std(accuracy_train).item<double>()); partial_result.setScoreTestStd(torch::std(accuracy_test).item<double>()).setScoreTrainStd(torch::std(accuracy_train).item<double>());
partial_result.setTrainTime(torch::mean(train_time).item<double>()).setTestTime(torch::mean(test_time).item<double>()); partial_result.setTrainTime(torch::mean(train_time).item<double>()).setTestTime(torch::mean(test_time).item<double>());

View File

@@ -42,35 +42,37 @@ namespace platform {
sline += "\n"; sline += "\n";
header.push_back(sline); header.push_back(sline);
int num = 0; int num = 0;
for (const auto& dataset : datasets.getNames()) { for (const auto& dataset_name : datasets.getNames()) {
std::stringstream line; std::stringstream line;
line.imbue(loc); line.imbue(loc);
auto color = num % 2 ? Colors::CYAN() : Colors::BLUE(); auto color = num % 2 ? Colors::CYAN() : Colors::BLUE();
line << color << setw(3) << right << num++ << " "; line << color << setw(3) << right << num++ << " ";
line << setw(maxName) << left << dataset << " "; line << setw(maxName) << left << dataset_name << " ";
datasets.loadDataset(dataset); auto& dataset = datasets.getDataset(dataset_name);
auto nSamples = datasets.getNSamples(dataset); dataset.load();
auto nSamples = dataset.getNSamples();
line << setw(6) << right << nSamples << " "; line << setw(6) << right << nSamples << " ";
auto nFeatures = datasets.getFeatures(dataset).size(); auto nFeatures = dataset.getFeatures().size();
line << setw(5) << right << nFeatures << " "; line << setw(5) << right << nFeatures << " ";
auto numericFeatures = datasets.getNumericFeatures(dataset); auto numericFeatures = dataset.getNumericFeatures();
auto num = std::count(numericFeatures.begin(), numericFeatures.end(), true); auto num = std::count(numericFeatures.begin(), numericFeatures.end(), true);
line << setw(5) << right << num << " "; line << setw(5) << right << num << " ";
line << setw(3) << right << datasets.getNClasses(dataset) << " "; auto nClasses = dataset.getNClasses();
line << setw(3) << right << nClasses << " ";
std::string sep = ""; std::string sep = "";
oss.str(""); oss.str("");
for (auto number : datasets.getClassesCounts(dataset)) { for (auto number : dataset.getClassesCounts()) {
oss << sep << std::setprecision(2) << fixed << (float)number / nSamples * 100.0 << "% (" << number << ")"; oss << sep << std::setprecision(2) << fixed << (float)number / nSamples * 100.0 << "% (" << number << ")";
sep = " / "; sep = " / ";
} }
split_lines(maxName, line.str(), oss.str()); split_lines(maxName, line.str(), oss.str());
// Store data for Excel report // Store data for Excel report
data[dataset] = json::object(); data[dataset_name] = json::object();
data[dataset]["samples"] = nSamples; data[dataset_name]["samples"] = nSamples;
data[dataset]["features"] = datasets.getFeatures(dataset).size(); data[dataset_name]["features"] = nFeatures;
data[dataset]["numericFeatures"] = num; data[dataset_name]["numericFeatures"] = num;
data[dataset]["classes"] = datasets.getNClasses(dataset); data[dataset_name]["classes"] = nClasses;
data[dataset]["balance"] = oss.str(); data[dataset_name]["balance"] = oss.str();
} }
} }
} }

View File

@@ -61,12 +61,13 @@ namespace platform {
} }
} else { } else {
if (data["score_name"].get<std::string>() == "accuracy") { if (data["score_name"].get<std::string>() == "accuracy") {
auto dt = Datasets(false, Paths::datasets()); auto datasets = Datasets(false, Paths::datasets());
dt.loadDataset(dataset); auto& dt = datasets.getDataset(dataset);
auto numClasses = dt.getNClasses(dataset); dt.load();
auto numClasses = dt.getNClasses();
if (numClasses == 2) { if (numClasses == 2) {
std::vector<int> distribution = dt.getClassesCounts(dataset); std::vector<int> distribution = dt.getClassesCounts();
double nSamples = dt.getNSamples(dataset); double nSamples = dt.getNSamples();
std::vector<int>::iterator maxValue = max_element(distribution.begin(), distribution.end()); std::vector<int>::iterator maxValue = max_element(distribution.begin(), distribution.end());
double mark = *maxValue / nSamples * (1 + margin); double mark = *maxValue / nSamples * (1 + margin);
if (mark > 1) { if (mark > 1) {