Complete first step
This commit is contained in:
parent
bbe5302ab1
commit
8b7b59d42b
@ -36,7 +36,7 @@ namespace platform {
|
|||||||
return Colors::RESET();
|
return Colors::RESET();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
void GridSearch::processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters)
|
double GridSearch::processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters)
|
||||||
{
|
{
|
||||||
// Get dataset
|
// Get dataset
|
||||||
auto [X, y] = datasets.getTensors(fileName);
|
auto [X, y] = datasets.getTensors(fileName);
|
||||||
@ -44,6 +44,8 @@ namespace platform {
|
|||||||
auto features = datasets.getFeatures(fileName);
|
auto features = datasets.getFeatures(fileName);
|
||||||
auto samples = datasets.getNSamples(fileName);
|
auto samples = datasets.getNSamples(fileName);
|
||||||
auto className = datasets.getClassName(fileName);
|
auto className = datasets.getClassName(fileName);
|
||||||
|
double totalScore = 0.0;
|
||||||
|
int numItems = 0;
|
||||||
for (const auto& seed : config.seeds) {
|
for (const auto& seed : config.seeds) {
|
||||||
std::cout << "(" << seed << ") doing Fold: " << flush;
|
std::cout << "(" << seed << ") doing Fold: " << flush;
|
||||||
Fold* fold;
|
Fold* fold;
|
||||||
@ -51,8 +53,10 @@ namespace platform {
|
|||||||
fold = new StratifiedKFold(config.n_folds, y, seed);
|
fold = new StratifiedKFold(config.n_folds, y, seed);
|
||||||
else
|
else
|
||||||
fold = new KFold(config.n_folds, y.size(0), seed);
|
fold = new KFold(config.n_folds, y.size(0), seed);
|
||||||
|
double bestScore = 0.0;
|
||||||
for (int nfold = 0; nfold < config.n_folds; nfold++) {
|
for (int nfold = 0; nfold < config.n_folds; nfold++) {
|
||||||
auto clf = Models::instance()->create(config.model);
|
auto clf = Models::instance()->create(config.model);
|
||||||
|
clf->setHyperparameters(hyperparameters.get(fileName));
|
||||||
auto [train, test] = fold->getFold(nfold);
|
auto [train, test] = fold->getFold(nfold);
|
||||||
auto train_t = torch::tensor(train);
|
auto train_t = torch::tensor(train);
|
||||||
auto test_t = torch::tensor(test);
|
auto test_t = torch::tensor(test);
|
||||||
@ -60,15 +64,18 @@ namespace platform {
|
|||||||
auto y_train = y.index({ train_t });
|
auto y_train = y.index({ train_t });
|
||||||
auto X_test = X.index({ "...", test_t });
|
auto X_test = X.index({ "...", test_t });
|
||||||
auto y_test = y.index({ test_t });
|
auto y_test = y.index({ test_t });
|
||||||
showProgressFold(nfold + 1, getColor(clf->getStatus()), "a");
|
|
||||||
// Train model
|
// Train model
|
||||||
// clf->fit(X_train, y_train, features, className, states);
|
clf->fit(X_train, y_train, features, className, states);
|
||||||
|
showProgressFold(nfold + 1, getColor(clf->getStatus()), "a");
|
||||||
showProgressFold(nfold + 1, getColor(clf->getStatus()), "b");
|
showProgressFold(nfold + 1, getColor(clf->getStatus()), "b");
|
||||||
|
totalScore += clf->score(X_test, y_test);
|
||||||
|
numItems++;
|
||||||
showProgressFold(nfold + 1, getColor(clf->getStatus()), "c");
|
showProgressFold(nfold + 1, getColor(clf->getStatus()), "c");
|
||||||
std::cout << "\b\b\b, \b" << flush;
|
std::cout << "\b\b\b, \b" << flush;
|
||||||
}
|
}
|
||||||
delete fold;
|
delete fold;
|
||||||
}
|
}
|
||||||
|
return numItems == 0 ? 0.0 : totalScore / numItems;
|
||||||
}
|
}
|
||||||
void GridSearch::go()
|
void GridSearch::go()
|
||||||
{
|
{
|
||||||
@ -83,12 +90,21 @@ namespace platform {
|
|||||||
for (const auto& dataset : datasets.getNames()) {
|
for (const auto& dataset : datasets.getNames()) {
|
||||||
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
|
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
|
||||||
int num = 0;
|
int num = 0;
|
||||||
|
double bestScore = 0.0;
|
||||||
|
json bestHyperparameters;
|
||||||
for (const auto& hyperparam_line : grid.getGrid(config.model)) {
|
for (const auto& hyperparam_line : grid.getGrid(config.model)) {
|
||||||
showProgressComb(++num, totalComb, Colors::CYAN());
|
showProgressComb(++num, totalComb, Colors::CYAN());
|
||||||
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
|
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
|
||||||
processFile(dataset, datasets, hyperparameters);
|
double score = processFile(dataset, datasets, hyperparameters);
|
||||||
|
if (score > bestScore) {
|
||||||
|
bestScore = score;
|
||||||
|
bestHyperparameters = hyperparam_line;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
std::cout << "end." << std::endl;
|
std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed
|
||||||
|
<< bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl;
|
||||||
|
results[dataset]["score"] = bestScore;
|
||||||
|
results[dataset]["hyperparameters"] = bestHyperparameters;
|
||||||
}
|
}
|
||||||
// Save results
|
// Save results
|
||||||
save();
|
save();
|
||||||
@ -96,7 +112,7 @@ namespace platform {
|
|||||||
void GridSearch::save()
|
void GridSearch::save()
|
||||||
{
|
{
|
||||||
std::ofstream file(config.output_file);
|
std::ofstream file(config.output_file);
|
||||||
// file << results.dump(4);
|
file << results.dump(4);
|
||||||
file.close();
|
file.close();
|
||||||
}
|
}
|
||||||
} /* namespace platform */
|
} /* namespace platform */
|
@ -2,11 +2,13 @@
|
|||||||
#define GRIDSEARCH_H
|
#define GRIDSEARCH_H
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
#include "Datasets.h"
|
#include "Datasets.h"
|
||||||
#include "HyperParameters.h"
|
#include "HyperParameters.h"
|
||||||
#include "GridData.h"
|
#include "GridData.h"
|
||||||
|
|
||||||
namespace platform {
|
namespace platform {
|
||||||
|
using json = nlohmann::json;
|
||||||
struct ConfigGrid {
|
struct ConfigGrid {
|
||||||
std::string model;
|
std::string model;
|
||||||
std::string score;
|
std::string score;
|
||||||
@ -25,7 +27,8 @@ namespace platform {
|
|||||||
void save();
|
void save();
|
||||||
~GridSearch() = default;
|
~GridSearch() = default;
|
||||||
private:
|
private:
|
||||||
void processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters);
|
double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters);
|
||||||
|
json results;
|
||||||
struct ConfigGrid config;
|
struct ConfigGrid config;
|
||||||
GridData grid;
|
GridData grid;
|
||||||
};
|
};
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
#ifndef PATHS_H
|
#ifndef PATHS_H
|
||||||
#define PATHS_H
|
#define PATHS_H
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <filesystem>
|
||||||
#include "DotEnv.h"
|
#include "DotEnv.h"
|
||||||
namespace platform {
|
namespace platform {
|
||||||
class Paths {
|
class Paths {
|
||||||
@ -8,13 +9,22 @@ namespace platform {
|
|||||||
static std::string results() { return "results/"; }
|
static std::string results() { return "results/"; }
|
||||||
static std::string hiddenResults() { return "hidden_results/"; }
|
static std::string hiddenResults() { return "hidden_results/"; }
|
||||||
static std::string excel() { return "excel/"; }
|
static std::string excel() { return "excel/"; }
|
||||||
static std::string cfs() { return "cfs/"; }
|
|
||||||
static std::string grid() { return "grid/"; }
|
static std::string grid() { return "grid/"; }
|
||||||
static std::string datasets()
|
static std::string datasets()
|
||||||
{
|
{
|
||||||
auto env = platform::DotEnv();
|
auto env = platform::DotEnv();
|
||||||
return env.get("source_data");
|
return env.get("source_data");
|
||||||
}
|
}
|
||||||
|
static void createPath(const std::string& path)
|
||||||
|
{
|
||||||
|
// Create directory if it does not exist
|
||||||
|
try {
|
||||||
|
std::filesystem::create_directory(path);
|
||||||
|
}
|
||||||
|
catch (std::exception& e) {
|
||||||
|
throw std::runtime_error("Could not create directory " + path);
|
||||||
|
}
|
||||||
|
}
|
||||||
static std::string excelResults() { return "some_results.xlsx"; }
|
static std::string excelResults() { return "some_results.xlsx"; }
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -66,6 +66,7 @@ int main(int argc, char** argv)
|
|||||||
* Begin Processing
|
* Begin Processing
|
||||||
*/
|
*/
|
||||||
auto env = platform::DotEnv();
|
auto env = platform::DotEnv();
|
||||||
|
platform::Paths::createPath(platform::Paths::grid());
|
||||||
config.path = platform::Paths::grid();
|
config.path = platform::Paths::grid();
|
||||||
auto grid_search = platform::GridSearch(config);
|
auto grid_search = platform::GridSearch(config);
|
||||||
platform::Timer timer;
|
platform::Timer timer;
|
||||||
|
Loading…
Reference in New Issue
Block a user