Complete first step

This commit is contained in:
Ricardo Montañana Gómez 2023-11-23 12:59:21 +01:00
parent bbe5302ab1
commit 8b7b59d42b
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
4 changed files with 38 additions and 8 deletions

View File

@ -36,7 +36,7 @@ namespace platform {
return Colors::RESET(); return Colors::RESET();
} }
} }
void GridSearch::processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters) double GridSearch::processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters)
{ {
// Get dataset // Get dataset
auto [X, y] = datasets.getTensors(fileName); auto [X, y] = datasets.getTensors(fileName);
@ -44,6 +44,8 @@ namespace platform {
auto features = datasets.getFeatures(fileName); auto features = datasets.getFeatures(fileName);
auto samples = datasets.getNSamples(fileName); auto samples = datasets.getNSamples(fileName);
auto className = datasets.getClassName(fileName); auto className = datasets.getClassName(fileName);
double totalScore = 0.0;
int numItems = 0;
for (const auto& seed : config.seeds) { for (const auto& seed : config.seeds) {
std::cout << "(" << seed << ") doing Fold: " << flush; std::cout << "(" << seed << ") doing Fold: " << flush;
Fold* fold; Fold* fold;
@ -51,8 +53,10 @@ namespace platform {
fold = new StratifiedKFold(config.n_folds, y, seed); fold = new StratifiedKFold(config.n_folds, y, seed);
else else
fold = new KFold(config.n_folds, y.size(0), seed); fold = new KFold(config.n_folds, y.size(0), seed);
double bestScore = 0.0;
for (int nfold = 0; nfold < config.n_folds; nfold++) { for (int nfold = 0; nfold < config.n_folds; nfold++) {
auto clf = Models::instance()->create(config.model); auto clf = Models::instance()->create(config.model);
clf->setHyperparameters(hyperparameters.get(fileName));
auto [train, test] = fold->getFold(nfold); auto [train, test] = fold->getFold(nfold);
auto train_t = torch::tensor(train); auto train_t = torch::tensor(train);
auto test_t = torch::tensor(test); auto test_t = torch::tensor(test);
@ -60,15 +64,18 @@ namespace platform {
auto y_train = y.index({ train_t }); auto y_train = y.index({ train_t });
auto X_test = X.index({ "...", test_t }); auto X_test = X.index({ "...", test_t });
auto y_test = y.index({ test_t }); auto y_test = y.index({ test_t });
showProgressFold(nfold + 1, getColor(clf->getStatus()), "a");
// Train model // Train model
// clf->fit(X_train, y_train, features, className, states); clf->fit(X_train, y_train, features, className, states);
showProgressFold(nfold + 1, getColor(clf->getStatus()), "a");
showProgressFold(nfold + 1, getColor(clf->getStatus()), "b"); showProgressFold(nfold + 1, getColor(clf->getStatus()), "b");
totalScore += clf->score(X_test, y_test);
numItems++;
showProgressFold(nfold + 1, getColor(clf->getStatus()), "c"); showProgressFold(nfold + 1, getColor(clf->getStatus()), "c");
std::cout << "\b\b\b, \b" << flush; std::cout << "\b\b\b, \b" << flush;
} }
delete fold; delete fold;
} }
return numItems == 0 ? 0.0 : totalScore / numItems;
} }
void GridSearch::go() void GridSearch::go()
{ {
@ -83,12 +90,21 @@ namespace platform {
for (const auto& dataset : datasets.getNames()) { for (const auto& dataset : datasets.getNames()) {
std::cout << "- " << setw(20) << left << dataset << " " << right << flush; std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
int num = 0; int num = 0;
double bestScore = 0.0;
json bestHyperparameters;
for (const auto& hyperparam_line : grid.getGrid(config.model)) { for (const auto& hyperparam_line : grid.getGrid(config.model)) {
showProgressComb(++num, totalComb, Colors::CYAN()); showProgressComb(++num, totalComb, Colors::CYAN());
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line); auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
processFile(dataset, datasets, hyperparameters); double score = processFile(dataset, datasets, hyperparameters);
if (score > bestScore) {
bestScore = score;
bestHyperparameters = hyperparam_line;
}
} }
std::cout << "end." << std::endl; std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed
<< bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl;
results[dataset]["score"] = bestScore;
results[dataset]["hyperparameters"] = bestHyperparameters;
} }
// Save results // Save results
save(); save();
@ -96,7 +112,7 @@ namespace platform {
void GridSearch::save() void GridSearch::save()
{ {
std::ofstream file(config.output_file); std::ofstream file(config.output_file);
// file << results.dump(4); file << results.dump(4);
file.close(); file.close();
} }
} /* namespace platform */ } /* namespace platform */

View File

@ -2,11 +2,13 @@
#define GRIDSEARCH_H #define GRIDSEARCH_H
#include <string> #include <string>
#include <vector> #include <vector>
#include <nlohmann/json.hpp>
#include "Datasets.h" #include "Datasets.h"
#include "HyperParameters.h" #include "HyperParameters.h"
#include "GridData.h" #include "GridData.h"
namespace platform { namespace platform {
using json = nlohmann::json;
struct ConfigGrid { struct ConfigGrid {
std::string model; std::string model;
std::string score; std::string score;
@ -25,7 +27,8 @@ namespace platform {
void save(); void save();
~GridSearch() = default; ~GridSearch() = default;
private: private:
void processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters); double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters);
json results;
struct ConfigGrid config; struct ConfigGrid config;
GridData grid; GridData grid;
}; };

View File

@ -1,6 +1,7 @@
#ifndef PATHS_H #ifndef PATHS_H
#define PATHS_H #define PATHS_H
#include <string> #include <string>
#include <filesystem>
#include "DotEnv.h" #include "DotEnv.h"
namespace platform { namespace platform {
class Paths { class Paths {
@ -8,13 +9,22 @@ namespace platform {
static std::string results() { return "results/"; } static std::string results() { return "results/"; }
static std::string hiddenResults() { return "hidden_results/"; } static std::string hiddenResults() { return "hidden_results/"; }
static std::string excel() { return "excel/"; } static std::string excel() { return "excel/"; }
static std::string cfs() { return "cfs/"; }
static std::string grid() { return "grid/"; } static std::string grid() { return "grid/"; }
static std::string datasets() static std::string datasets()
{ {
auto env = platform::DotEnv(); auto env = platform::DotEnv();
return env.get("source_data"); return env.get("source_data");
} }
static void createPath(const std::string& path)
{
// Create directory if it does not exist
try {
std::filesystem::create_directory(path);
}
catch (std::exception& e) {
throw std::runtime_error("Could not create directory " + path);
}
}
static std::string excelResults() { return "some_results.xlsx"; } static std::string excelResults() { return "some_results.xlsx"; }
}; };
} }

View File

@ -66,6 +66,7 @@ int main(int argc, char** argv)
* Begin Processing * Begin Processing
*/ */
auto env = platform::DotEnv(); auto env = platform::DotEnv();
platform::Paths::createPath(platform::Paths::grid());
config.path = platform::Paths::grid(); config.path = platform::Paths::grid();
auto grid_search = platform::GridSearch(config); auto grid_search = platform::GridSearch(config);
platform::Timer timer; platform::Timer timer;