diff --git a/src/Platform/GridSearch.cc b/src/Platform/GridSearch.cc index 1241ea8..239cf50 100644 --- a/src/Platform/GridSearch.cc +++ b/src/Platform/GridSearch.cc @@ -36,7 +36,7 @@ namespace platform { return Colors::RESET(); } } - void GridSearch::processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters) + double GridSearch::processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters) { // Get dataset auto [X, y] = datasets.getTensors(fileName); @@ -44,6 +44,8 @@ namespace platform { auto features = datasets.getFeatures(fileName); auto samples = datasets.getNSamples(fileName); auto className = datasets.getClassName(fileName); + double totalScore = 0.0; + int numItems = 0; for (const auto& seed : config.seeds) { std::cout << "(" << seed << ") doing Fold: " << flush; Fold* fold; @@ -51,8 +53,10 @@ namespace platform { fold = new StratifiedKFold(config.n_folds, y, seed); else fold = new KFold(config.n_folds, y.size(0), seed); + double bestScore = 0.0; for (int nfold = 0; nfold < config.n_folds; nfold++) { auto clf = Models::instance()->create(config.model); + clf->setHyperparameters(hyperparameters.get(fileName)); auto [train, test] = fold->getFold(nfold); auto train_t = torch::tensor(train); auto test_t = torch::tensor(test); @@ -60,15 +64,18 @@ namespace platform { auto y_train = y.index({ train_t }); auto X_test = X.index({ "...", test_t }); auto y_test = y.index({ test_t }); - showProgressFold(nfold + 1, getColor(clf->getStatus()), "a"); // Train model - // clf->fit(X_train, y_train, features, className, states); + clf->fit(X_train, y_train, features, className, states); + showProgressFold(nfold + 1, getColor(clf->getStatus()), "a"); showProgressFold(nfold + 1, getColor(clf->getStatus()), "b"); + totalScore += clf->score(X_test, y_test); + numItems++; showProgressFold(nfold + 1, getColor(clf->getStatus()), "c"); std::cout << "\b\b\b, \b" << flush; } delete fold; } + return numItems == 0 ? 0.0 : totalScore / numItems; } void GridSearch::go() { @@ -83,12 +90,21 @@ namespace platform { for (const auto& dataset : datasets.getNames()) { std::cout << "- " << setw(20) << left << dataset << " " << right << flush; int num = 0; + double bestScore = 0.0; + json bestHyperparameters; for (const auto& hyperparam_line : grid.getGrid(config.model)) { showProgressComb(++num, totalComb, Colors::CYAN()); auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line); - processFile(dataset, datasets, hyperparameters); + double score = processFile(dataset, datasets, hyperparameters); + if (score > bestScore) { + bestScore = score; + bestHyperparameters = hyperparam_line; + } } - std::cout << "end." << std::endl; + std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed + << bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl; + results[dataset]["score"] = bestScore; + results[dataset]["hyperparameters"] = bestHyperparameters; } // Save results save(); @@ -96,7 +112,7 @@ namespace platform { void GridSearch::save() { std::ofstream file(config.output_file); - // file << results.dump(4); + file << results.dump(4); file.close(); } } /* namespace platform */ \ No newline at end of file diff --git a/src/Platform/GridSearch.h b/src/Platform/GridSearch.h index 220eccc..6bf9f1a 100644 --- a/src/Platform/GridSearch.h +++ b/src/Platform/GridSearch.h @@ -2,11 +2,13 @@ #define GRIDSEARCH_H #include #include +#include #include "Datasets.h" #include "HyperParameters.h" #include "GridData.h" namespace platform { + using json = nlohmann::json; struct ConfigGrid { std::string model; std::string score; @@ -25,7 +27,8 @@ namespace platform { void save(); ~GridSearch() = default; private: - void processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters); + double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters); + json results; struct ConfigGrid config; GridData grid; }; diff --git a/src/Platform/Paths.h b/src/Platform/Paths.h index d3c4422..3f5d135 100644 --- a/src/Platform/Paths.h +++ b/src/Platform/Paths.h @@ -1,6 +1,7 @@ #ifndef PATHS_H #define PATHS_H #include +#include #include "DotEnv.h" namespace platform { class Paths { @@ -8,13 +9,22 @@ namespace platform { static std::string results() { return "results/"; } static std::string hiddenResults() { return "hidden_results/"; } static std::string excel() { return "excel/"; } - static std::string cfs() { return "cfs/"; } static std::string grid() { return "grid/"; } static std::string datasets() { auto env = platform::DotEnv(); return env.get("source_data"); } + static void createPath(const std::string& path) + { + // Create directory if it does not exist + try { + std::filesystem::create_directory(path); + } + catch (std::exception& e) { + throw std::runtime_error("Could not create directory " + path); + } + } static std::string excelResults() { return "some_results.xlsx"; } }; } diff --git a/src/Platform/b_grid.cc b/src/Platform/b_grid.cc index de905bf..b66050d 100644 --- a/src/Platform/b_grid.cc +++ b/src/Platform/b_grid.cc @@ -66,6 +66,7 @@ int main(int argc, char** argv) * Begin Processing */ auto env = platform::DotEnv(); + platform::Paths::createPath(platform::Paths::grid()); config.path = platform::Paths::grid(); auto grid_search = platform::GridSearch(config); platform::Timer timer;