gridsearch #13
3
.gitignore
vendored
3
.gitignore
vendored
@ -32,8 +32,7 @@
|
|||||||
*.out
|
*.out
|
||||||
*.app
|
*.app
|
||||||
build/**
|
build/**
|
||||||
build_debug/**
|
build_*/**
|
||||||
build_release/**
|
|
||||||
*.dSYM/**
|
*.dSYM/**
|
||||||
cmake-build*/**
|
cmake-build*/**
|
||||||
.idea
|
.idea
|
||||||
|
12
Makefile
12
Makefile
@ -4,7 +4,7 @@ SHELL := /bin/bash
|
|||||||
|
|
||||||
f_release = build_release
|
f_release = build_release
|
||||||
f_debug = build_debug
|
f_debug = build_debug
|
||||||
app_targets = b_best b_list b_main b_manage
|
app_targets = b_best b_list b_main b_manage b_grid
|
||||||
test_targets = unit_tests_bayesnet unit_tests_platform
|
test_targets = unit_tests_bayesnet unit_tests_platform
|
||||||
n_procs = -j 16
|
n_procs = -j 16
|
||||||
|
|
||||||
@ -35,11 +35,13 @@ dest ?= ${HOME}/bin
|
|||||||
install: ## Copy binary files to bin folder
|
install: ## Copy binary files to bin folder
|
||||||
@echo "Destination folder: $(dest)"
|
@echo "Destination folder: $(dest)"
|
||||||
make buildr
|
make buildr
|
||||||
|
@echo "*******************************************"
|
||||||
@echo ">>> Copying files to $(dest)"
|
@echo ">>> Copying files to $(dest)"
|
||||||
@cp $(f_release)/src/Platform/b_main $(dest)
|
@echo "*******************************************"
|
||||||
@cp $(f_release)/src/Platform/b_list $(dest)
|
@for item in $(app_targets); do \
|
||||||
@cp $(f_release)/src/Platform/b_manage $(dest)
|
echo ">>> Copying $$item" ; \
|
||||||
@cp $(f_release)/src/Platform/b_best $(dest)
|
cp $(f_release)/src/Platform/$$item $(dest) ; \
|
||||||
|
done
|
||||||
|
|
||||||
dependency: ## Create a dependency graph diagram of the project (build/dependency.png)
|
dependency: ## Create a dependency graph diagram of the project (build/dependency.png)
|
||||||
@echo ">>> Creating dependency graph diagram of the project...";
|
@echo ">>> Creating dependency graph diagram of the project...";
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <torch/torch.h>
|
#include <torch/torch.h>
|
||||||
#include <std::string>
|
#include <string>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <argparse/argparse.hpp>
|
#include <argparse/argparse.hpp>
|
||||||
#include <nlohmann/json.hpp>
|
#include <nlohmann/json.hpp>
|
||||||
|
@ -8,12 +8,14 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/json/include)
|
|||||||
include_directories(${BayesNet_SOURCE_DIR}/lib/libxlsxwriter/include)
|
include_directories(${BayesNet_SOURCE_DIR}/lib/libxlsxwriter/include)
|
||||||
include_directories(${Python3_INCLUDE_DIRS})
|
include_directories(${Python3_INCLUDE_DIRS})
|
||||||
|
|
||||||
|
add_executable(b_best b_best.cc BestResults.cc Result.cc Statistics.cc BestResultsExcel.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc)
|
||||||
|
add_executable(b_grid b_grid.cc GridSearch.cc GridData.cc HyperParameters.cc Folding.cc Datasets.cc Dataset.cc)
|
||||||
|
add_executable(b_list b_list.cc Datasets.cc Dataset.cc)
|
||||||
add_executable(b_main b_main.cc Folding.cc Experiment.cc Datasets.cc Dataset.cc Models.cc HyperParameters.cc ReportConsole.cc ReportBase.cc)
|
add_executable(b_main b_main.cc Folding.cc Experiment.cc Datasets.cc Dataset.cc Models.cc HyperParameters.cc ReportConsole.cc ReportBase.cc)
|
||||||
add_executable(b_manage b_manage.cc Results.cc ManageResults.cc CommandParser.cc Result.cc ReportConsole.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc)
|
add_executable(b_manage b_manage.cc Results.cc ManageResults.cc CommandParser.cc Result.cc ReportConsole.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc)
|
||||||
add_executable(b_list b_list.cc Datasets.cc Dataset.cc)
|
|
||||||
add_executable(b_best b_best.cc BestResults.cc Result.cc Statistics.cc BestResultsExcel.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc)
|
|
||||||
|
|
||||||
target_link_libraries(b_main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}" PyWrap)
|
|
||||||
target_link_libraries(b_manage "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" ArffFiles mdlp)
|
|
||||||
target_link_libraries(b_best Boost::boost "${XLSXWRITER_LIB}" "${TORCH_LIBRARIES}" ArffFiles mdlp)
|
target_link_libraries(b_best Boost::boost "${XLSXWRITER_LIB}" "${TORCH_LIBRARIES}" ArffFiles mdlp)
|
||||||
target_link_libraries(b_list ArffFiles mdlp "${TORCH_LIBRARIES}")
|
target_link_libraries(b_grid BayesNet PyWrap)
|
||||||
|
target_link_libraries(b_list ArffFiles mdlp "${TORCH_LIBRARIES}")
|
||||||
|
target_link_libraries(b_main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}" PyWrap)
|
||||||
|
target_link_libraries(b_manage "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" ArffFiles mdlp)
|
@ -133,7 +133,7 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
void Experiment::cross_validation(const std::string& fileName, bool quiet)
|
void Experiment::cross_validation(const std::string& fileName, bool quiet)
|
||||||
{
|
{
|
||||||
auto datasets = platform::Datasets(discretized, Paths::datasets());
|
auto datasets = Datasets(discretized, Paths::datasets());
|
||||||
// Get dataset
|
// Get dataset
|
||||||
auto [X, y] = datasets.getTensors(fileName);
|
auto [X, y] = datasets.getTensors(fileName);
|
||||||
auto states = datasets.getStates(fileName);
|
auto states = datasets.getStates(fileName);
|
||||||
|
@ -3,30 +3,16 @@
|
|||||||
#include <torch/torch.h>
|
#include <torch/torch.h>
|
||||||
#include <nlohmann/json.hpp>
|
#include <nlohmann/json.hpp>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <chrono>
|
|
||||||
#include "Folding.h"
|
#include "Folding.h"
|
||||||
#include "BaseClassifier.h"
|
#include "BaseClassifier.h"
|
||||||
#include "HyperParameters.h"
|
#include "HyperParameters.h"
|
||||||
#include "TAN.h"
|
#include "TAN.h"
|
||||||
#include "KDB.h"
|
#include "KDB.h"
|
||||||
#include "AODE.h"
|
#include "AODE.h"
|
||||||
|
#include "Timer.h"
|
||||||
|
|
||||||
namespace platform {
|
namespace platform {
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
class Timer {
|
|
||||||
private:
|
|
||||||
std::chrono::high_resolution_clock::time_point begin;
|
|
||||||
public:
|
|
||||||
Timer() = default;
|
|
||||||
~Timer() = default;
|
|
||||||
void start() { begin = std::chrono::high_resolution_clock::now(); }
|
|
||||||
double getDuration()
|
|
||||||
{
|
|
||||||
std::chrono::high_resolution_clock::time_point end = std::chrono::high_resolution_clock::now();
|
|
||||||
std::chrono::duration<double> time_span = std::chrono::duration_cast<std::chrono::duration<double >> (end - begin);
|
|
||||||
return time_span.count();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
class Result {
|
class Result {
|
||||||
private:
|
private:
|
||||||
std::string dataset, model_version;
|
std::string dataset, model_version;
|
||||||
|
55
src/Platform/GridData.cc
Normal file
55
src/Platform/GridData.cc
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
#include "GridData.h"
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
|
namespace platform {
|
||||||
|
GridData::GridData(const std::string& fileName)
|
||||||
|
{
|
||||||
|
std::ifstream resultData(fileName);
|
||||||
|
if (resultData.is_open()) {
|
||||||
|
grid = json::parse(resultData);
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument("Unable to open input file. [" + fileName + "]");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int GridData::computeNumCombinations(const json& line)
|
||||||
|
{
|
||||||
|
int numCombinations = 1;
|
||||||
|
for (const auto& item : line.items()) {
|
||||||
|
numCombinations *= item.value().size();
|
||||||
|
}
|
||||||
|
return numCombinations;
|
||||||
|
}
|
||||||
|
int GridData::getNumCombinations()
|
||||||
|
{
|
||||||
|
int numCombinations = 0;
|
||||||
|
for (const auto& line : grid) {
|
||||||
|
numCombinations += computeNumCombinations(line);
|
||||||
|
}
|
||||||
|
return numCombinations;
|
||||||
|
}
|
||||||
|
json GridData::generateCombinations(json::iterator index, const json::iterator last, std::vector<json>& output, json currentCombination)
|
||||||
|
{
|
||||||
|
if (index == last) {
|
||||||
|
// If we reached the end of input, store the current combination
|
||||||
|
output.push_back(currentCombination);
|
||||||
|
return currentCombination;
|
||||||
|
}
|
||||||
|
const auto& key = index.key();
|
||||||
|
const auto& values = index.value();
|
||||||
|
for (const auto& value : values) {
|
||||||
|
auto combination = currentCombination;
|
||||||
|
combination[key] = value;
|
||||||
|
json::iterator nextIndex = index;
|
||||||
|
generateCombinations(++nextIndex, last, output, combination);
|
||||||
|
}
|
||||||
|
return currentCombination;
|
||||||
|
}
|
||||||
|
std::vector<json> GridData::getGrid()
|
||||||
|
{
|
||||||
|
auto result = std::vector<json>();
|
||||||
|
for (json line : grid) {
|
||||||
|
generateCombinations(line.begin(), line.end(), result, json({}));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
} /* namespace platform */
|
22
src/Platform/GridData.h
Normal file
22
src/Platform/GridData.h
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
#ifndef GRIDDATA_H
|
||||||
|
#define GRIDDATA_H
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
|
||||||
|
namespace platform {
|
||||||
|
using json = nlohmann::json;
|
||||||
|
class GridData {
|
||||||
|
public:
|
||||||
|
explicit GridData(const std::string& fileName);
|
||||||
|
~GridData() = default;
|
||||||
|
std::vector<json> getGrid();
|
||||||
|
int getNumCombinations();
|
||||||
|
private:
|
||||||
|
json generateCombinations(json::iterator index, const json::iterator last, std::vector<json>& output, json currentCombination);
|
||||||
|
int computeNumCombinations(const json& line);
|
||||||
|
json grid;
|
||||||
|
};
|
||||||
|
} /* namespace platform */
|
||||||
|
#endif /* GRIDDATA_H */
|
130
src/Platform/GridSearch.cc
Normal file
130
src/Platform/GridSearch.cc
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
#include <iostream>
|
||||||
|
#include <torch/torch.h>
|
||||||
|
#include "GridSearch.h"
|
||||||
|
#include "Models.h"
|
||||||
|
#include "Paths.h"
|
||||||
|
#include "Folding.h"
|
||||||
|
#include "Colors.h"
|
||||||
|
|
||||||
|
namespace platform {
|
||||||
|
GridSearch::GridSearch(struct ConfigGrid& config) : config(config)
|
||||||
|
{
|
||||||
|
this->config.output_file = config.path + "grid_" + config.model + "_output.json";
|
||||||
|
this->config.input_file = config.path + "grid_" + config.model + "_input.json";
|
||||||
|
}
|
||||||
|
void showProgressComb(const int num, const int total, const std::string& color)
|
||||||
|
{
|
||||||
|
int spaces = int(log(total) / log(10)) + 1;
|
||||||
|
int magic = 37 + 2 * spaces;
|
||||||
|
std::string prefix = num == 1 ? "" : string(magic, '\b') + string(magic + 1, ' ') + string(magic + 1, '\b');
|
||||||
|
std::cout << prefix << color << "(" << setw(spaces) << num << "/" << setw(spaces) << total << ") " << Colors::RESET() << flush;
|
||||||
|
}
|
||||||
|
void showProgressFold(int fold, const std::string& color, const std::string& phase)
|
||||||
|
{
|
||||||
|
std::string prefix = phase == "a" ? "" : "\b\b\b\b";
|
||||||
|
std::cout << prefix << color << fold << Colors::RESET() << "(" << color << phase << Colors::RESET() << ")" << flush;
|
||||||
|
}
|
||||||
|
std::string getColor(bayesnet::status_t status)
|
||||||
|
{
|
||||||
|
switch (status) {
|
||||||
|
case bayesnet::NORMAL:
|
||||||
|
return Colors::GREEN();
|
||||||
|
case bayesnet::WARNING:
|
||||||
|
return Colors::YELLOW();
|
||||||
|
case bayesnet::ERROR:
|
||||||
|
return Colors::RED();
|
||||||
|
default:
|
||||||
|
return Colors::RESET();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
double GridSearch::processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters)
|
||||||
|
{
|
||||||
|
// Get dataset
|
||||||
|
auto [X, y] = datasets.getTensors(fileName);
|
||||||
|
auto states = datasets.getStates(fileName);
|
||||||
|
auto features = datasets.getFeatures(fileName);
|
||||||
|
auto samples = datasets.getNSamples(fileName);
|
||||||
|
auto className = datasets.getClassName(fileName);
|
||||||
|
double totalScore = 0.0;
|
||||||
|
int numItems = 0;
|
||||||
|
for (const auto& seed : config.seeds) {
|
||||||
|
if (!config.quiet)
|
||||||
|
std::cout << "(" << seed << ") doing Fold: " << flush;
|
||||||
|
Fold* fold;
|
||||||
|
if (config.stratified)
|
||||||
|
fold = new StratifiedKFold(config.n_folds, y, seed);
|
||||||
|
else
|
||||||
|
fold = new KFold(config.n_folds, y.size(0), seed);
|
||||||
|
double bestScore = 0.0;
|
||||||
|
for (int nfold = 0; nfold < config.n_folds; nfold++) {
|
||||||
|
auto clf = Models::instance()->create(config.model);
|
||||||
|
clf->setHyperparameters(hyperparameters.get(fileName));
|
||||||
|
auto [train, test] = fold->getFold(nfold);
|
||||||
|
auto train_t = torch::tensor(train);
|
||||||
|
auto test_t = torch::tensor(test);
|
||||||
|
auto X_train = X.index({ "...", train_t });
|
||||||
|
auto y_train = y.index({ train_t });
|
||||||
|
auto X_test = X.index({ "...", test_t });
|
||||||
|
auto y_test = y.index({ test_t });
|
||||||
|
// Train model
|
||||||
|
if (!config.quiet)
|
||||||
|
showProgressFold(nfold + 1, getColor(clf->getStatus()), "a");
|
||||||
|
clf->fit(X_train, y_train, features, className, states);
|
||||||
|
// Test model
|
||||||
|
if (!config.quiet)
|
||||||
|
showProgressFold(nfold + 1, getColor(clf->getStatus()), "b");
|
||||||
|
totalScore += clf->score(X_test, y_test);
|
||||||
|
numItems++;
|
||||||
|
if (!config.quiet)
|
||||||
|
std::cout << "\b\b\b, \b" << flush;
|
||||||
|
}
|
||||||
|
delete fold;
|
||||||
|
}
|
||||||
|
return numItems == 0 ? 0.0 : totalScore / numItems;
|
||||||
|
}
|
||||||
|
void GridSearch::go()
|
||||||
|
{
|
||||||
|
// Load datasets
|
||||||
|
auto datasets = Datasets(config.discretize, Paths::datasets());
|
||||||
|
// Create model
|
||||||
|
std::cout << "***************** Starting Gridsearch *****************" << std::endl;
|
||||||
|
std::cout << "input file=" << config.input_file << std::endl;
|
||||||
|
auto grid = GridData(config.input_file);
|
||||||
|
auto totalComb = grid.getNumCombinations();
|
||||||
|
std::cout << "* Doing " << totalComb << " combinations for each dataset/seed/fold" << std::endl;
|
||||||
|
// Generate hyperparameters grid & run gridsearch
|
||||||
|
// Check each combination of hyperparameters for each dataset and each seed
|
||||||
|
for (const auto& dataset : datasets.getNames()) {
|
||||||
|
if (!config.quiet)
|
||||||
|
std::cout << "- " << setw(20) << left << dataset << " " << right << flush;
|
||||||
|
int num = 0;
|
||||||
|
double bestScore = 0.0;
|
||||||
|
json bestHyperparameters;
|
||||||
|
for (const auto& hyperparam_line : grid.getGrid()) {
|
||||||
|
if (!config.quiet)
|
||||||
|
showProgressComb(++num, totalComb, Colors::CYAN());
|
||||||
|
auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line);
|
||||||
|
double score = processFile(dataset, datasets, hyperparameters);
|
||||||
|
if (score > bestScore) {
|
||||||
|
bestScore = score;
|
||||||
|
bestHyperparameters = hyperparam_line;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!config.quiet) {
|
||||||
|
std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed
|
||||||
|
<< bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl;
|
||||||
|
}
|
||||||
|
results[dataset]["score"] = bestScore;
|
||||||
|
results[dataset]["hyperparameters"] = bestHyperparameters;
|
||||||
|
}
|
||||||
|
// Save results
|
||||||
|
save();
|
||||||
|
std::cout << "***************** Ending Gridsearch *******************" << std::endl;
|
||||||
|
}
|
||||||
|
void GridSearch::save() const
|
||||||
|
{
|
||||||
|
std::ofstream file(config.output_file);
|
||||||
|
file << results.dump(4);
|
||||||
|
file.close();
|
||||||
|
}
|
||||||
|
} /* namespace platform */
|
36
src/Platform/GridSearch.h
Normal file
36
src/Platform/GridSearch.h
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
#ifndef GRIDSEARCH_H
|
||||||
|
#define GRIDSEARCH_H
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
#include "Datasets.h"
|
||||||
|
#include "HyperParameters.h"
|
||||||
|
#include "GridData.h"
|
||||||
|
|
||||||
|
namespace platform {
|
||||||
|
using json = nlohmann::json;
|
||||||
|
struct ConfigGrid {
|
||||||
|
std::string model;
|
||||||
|
std::string score;
|
||||||
|
std::string path;
|
||||||
|
std::string input_file;
|
||||||
|
std::string output_file;
|
||||||
|
bool quiet;
|
||||||
|
bool discretize;
|
||||||
|
bool stratified;
|
||||||
|
int n_folds;
|
||||||
|
std::vector<int> seeds;
|
||||||
|
};
|
||||||
|
class GridSearch {
|
||||||
|
public:
|
||||||
|
explicit GridSearch(struct ConfigGrid& config);
|
||||||
|
void go();
|
||||||
|
void save() const;
|
||||||
|
~GridSearch() = default;
|
||||||
|
private:
|
||||||
|
double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters);
|
||||||
|
json results;
|
||||||
|
struct ConfigGrid config;
|
||||||
|
};
|
||||||
|
} /* namespace platform */
|
||||||
|
#endif /* GRIDSEARCH_H */
|
@ -1,6 +1,7 @@
|
|||||||
#ifndef PATHS_H
|
#ifndef PATHS_H
|
||||||
#define PATHS_H
|
#define PATHS_H
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <filesystem>
|
||||||
#include "DotEnv.h"
|
#include "DotEnv.h"
|
||||||
namespace platform {
|
namespace platform {
|
||||||
class Paths {
|
class Paths {
|
||||||
@ -8,12 +9,22 @@ namespace platform {
|
|||||||
static std::string results() { return "results/"; }
|
static std::string results() { return "results/"; }
|
||||||
static std::string hiddenResults() { return "hidden_results/"; }
|
static std::string hiddenResults() { return "hidden_results/"; }
|
||||||
static std::string excel() { return "excel/"; }
|
static std::string excel() { return "excel/"; }
|
||||||
static std::string cfs() { return "cfs/"; }
|
static std::string grid() { return "grid/"; }
|
||||||
static std::string datasets()
|
static std::string datasets()
|
||||||
{
|
{
|
||||||
auto env = platform::DotEnv();
|
auto env = platform::DotEnv();
|
||||||
return env.get("source_data");
|
return env.get("source_data");
|
||||||
}
|
}
|
||||||
|
static void createPath(const std::string& path)
|
||||||
|
{
|
||||||
|
// Create directory if it does not exist
|
||||||
|
try {
|
||||||
|
std::filesystem::create_directory(path);
|
||||||
|
}
|
||||||
|
catch (std::exception& e) {
|
||||||
|
throw std::runtime_error("Could not create directory " + path);
|
||||||
|
}
|
||||||
|
}
|
||||||
static std::string excelResults() { return "some_results.xlsx"; }
|
static std::string excelResults() { return "some_results.xlsx"; }
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
34
src/Platform/Timer.h
Normal file
34
src/Platform/Timer.h
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
#ifndef TIMER_H
|
||||||
|
#define TIMER_H
|
||||||
|
#include <chrono>
|
||||||
|
#include <string>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
namespace platform {
|
||||||
|
class Timer {
|
||||||
|
private:
|
||||||
|
std::chrono::high_resolution_clock::time_point begin;
|
||||||
|
std::chrono::high_resolution_clock::time_point end;
|
||||||
|
public:
|
||||||
|
Timer() = default;
|
||||||
|
~Timer() = default;
|
||||||
|
void start() { begin = std::chrono::high_resolution_clock::now(); }
|
||||||
|
void stop() { end = std::chrono::high_resolution_clock::now(); }
|
||||||
|
double getDuration()
|
||||||
|
{
|
||||||
|
stop();
|
||||||
|
std::chrono::duration<double> time_span = std::chrono::duration_cast<std::chrono::duration<double >> (end - begin);
|
||||||
|
return time_span.count();
|
||||||
|
}
|
||||||
|
std::string getDurationString()
|
||||||
|
{
|
||||||
|
double duration = getDuration();
|
||||||
|
double durationShow = duration > 3600 ? duration / 3600 : duration > 60 ? duration / 60 : duration;
|
||||||
|
std::string durationUnit = duration > 3600 ? "h" : duration > 60 ? "m" : "s";
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << std::setw(7) << std::setprecision(2) << std::fixed << durationShow << " " << durationUnit << " ";
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} /* namespace platform */
|
||||||
|
#endif /* TIMER_H */
|
@ -7,7 +7,7 @@
|
|||||||
|
|
||||||
argparse::ArgumentParser manageArguments(int argc, char** argv)
|
argparse::ArgumentParser manageArguments(int argc, char** argv)
|
||||||
{
|
{
|
||||||
argparse::ArgumentParser program("best");
|
argparse::ArgumentParser program("b_sbest");
|
||||||
program.add_argument("-m", "--model").default_value("").help("Filter results of the selected model) (any for all models)");
|
program.add_argument("-m", "--model").default_value("").help("Filter results of the selected model) (any for all models)");
|
||||||
program.add_argument("-s", "--score").default_value("").help("Filter results of the score name supplied");
|
program.add_argument("-s", "--score").default_value("").help("Filter results of the score name supplied");
|
||||||
program.add_argument("--build").help("build best score results file").default_value(false).implicit_value(true);
|
program.add_argument("--build").help("build best score results file").default_value(false).implicit_value(true);
|
||||||
|
81
src/Platform/b_grid.cc
Normal file
81
src/Platform/b_grid.cc
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
#include <iostream>
|
||||||
|
#include <argparse/argparse.hpp>
|
||||||
|
#include "DotEnv.h"
|
||||||
|
#include "Models.h"
|
||||||
|
#include "modelRegister.h"
|
||||||
|
#include "GridSearch.h"
|
||||||
|
#include "Paths.h"
|
||||||
|
#include "Timer.h"
|
||||||
|
|
||||||
|
|
||||||
|
argparse::ArgumentParser manageArguments(std::string program_name)
|
||||||
|
{
|
||||||
|
auto env = platform::DotEnv();
|
||||||
|
argparse::ArgumentParser program(program_name);
|
||||||
|
program.add_argument("-m", "--model")
|
||||||
|
.help("Model to use " + platform::Models::instance()->tostring())
|
||||||
|
.action([](const std::string& value) {
|
||||||
|
static const std::vector<std::string> choices = platform::Models::instance()->getNames();
|
||||||
|
if (find(choices.begin(), choices.end(), value) != choices.end()) {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
throw std::runtime_error("Model must be one of " + platform::Models::instance()->tostring());
|
||||||
|
}
|
||||||
|
);
|
||||||
|
program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
|
||||||
|
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
|
||||||
|
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
|
||||||
|
program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy");
|
||||||
|
program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) {
|
||||||
|
try {
|
||||||
|
auto k = stoi(value);
|
||||||
|
if (k < 2) {
|
||||||
|
throw std::runtime_error("Number of folds must be greater than 1");
|
||||||
|
}
|
||||||
|
return k;
|
||||||
|
}
|
||||||
|
catch (const runtime_error& err) {
|
||||||
|
throw std::runtime_error(err.what());
|
||||||
|
}
|
||||||
|
catch (...) {
|
||||||
|
throw std::runtime_error("Number of folds must be an integer");
|
||||||
|
}});
|
||||||
|
auto seed_values = env.getSeeds();
|
||||||
|
program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values);
|
||||||
|
return program;
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char** argv)
|
||||||
|
{
|
||||||
|
auto program = manageArguments("b_grid");
|
||||||
|
struct platform::ConfigGrid config;
|
||||||
|
try {
|
||||||
|
program.parse_args(argc, argv);
|
||||||
|
config.model = program.get<std::string>("model");
|
||||||
|
config.score = program.get<std::string>("score");
|
||||||
|
config.discretize = program.get<bool>("discretize");
|
||||||
|
config.stratified = program.get<bool>("stratified");
|
||||||
|
config.n_folds = program.get<int>("folds");
|
||||||
|
config.quiet = program.get<bool>("quiet");
|
||||||
|
config.seeds = program.get<std::vector<int>>("seeds");
|
||||||
|
}
|
||||||
|
catch (const exception& err) {
|
||||||
|
cerr << err.what() << std::endl;
|
||||||
|
cerr << program;
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
* Begin Processing
|
||||||
|
*/
|
||||||
|
auto env = platform::DotEnv();
|
||||||
|
platform::Paths::createPath(platform::Paths::grid());
|
||||||
|
config.path = platform::Paths::grid();
|
||||||
|
auto grid_search = platform::GridSearch(config);
|
||||||
|
platform::Timer timer;
|
||||||
|
timer.start();
|
||||||
|
grid_search.go();
|
||||||
|
std::cout << "Process took " << timer.getDurationString() << std::endl;
|
||||||
|
grid_search.save();
|
||||||
|
std::cout << "Done!" << std::endl;
|
||||||
|
return 0;
|
||||||
|
}
|
@ -11,10 +11,10 @@
|
|||||||
|
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
|
|
||||||
argparse::ArgumentParser manageArguments()
|
argparse::ArgumentParser manageArguments(std::string program_name)
|
||||||
{
|
{
|
||||||
auto env = platform::DotEnv();
|
auto env = platform::DotEnv();
|
||||||
argparse::ArgumentParser program("main");
|
argparse::ArgumentParser program(program_name);
|
||||||
program.add_argument("-d", "--dataset").default_value("").help("Dataset file name");
|
program.add_argument("-d", "--dataset").default_value("").help("Dataset file name");
|
||||||
program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment");
|
program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment");
|
||||||
program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \
|
program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \
|
||||||
@ -61,7 +61,7 @@ int main(int argc, char** argv)
|
|||||||
std::vector<int> seeds;
|
std::vector<int> seeds;
|
||||||
std::vector<std::string> filesToTest;
|
std::vector<std::string> filesToTest;
|
||||||
int n_folds;
|
int n_folds;
|
||||||
auto program = manageArguments();
|
auto program = manageArguments("b_main");
|
||||||
try {
|
try {
|
||||||
program.parse_args(argc, argv);
|
program.parse_args(argc, argv);
|
||||||
file_name = program.get<std::string>("dataset");
|
file_name = program.get<std::string>("dataset");
|
||||||
|
@ -5,7 +5,7 @@
|
|||||||
|
|
||||||
argparse::ArgumentParser manageArguments(int argc, char** argv)
|
argparse::ArgumentParser manageArguments(int argc, char** argv)
|
||||||
{
|
{
|
||||||
argparse::ArgumentParser program("manage");
|
argparse::ArgumentParser program("b_manage");
|
||||||
program.add_argument("-n", "--number").default_value(0).help("Number of results to show (0 = all)").scan<'i', int>();
|
program.add_argument("-n", "--number").default_value(0).help("Number of results to show (0 = all)").scan<'i', int>();
|
||||||
program.add_argument("-m", "--model").default_value("any").help("Filter results of the selected model)");
|
program.add_argument("-m", "--model").default_value("any").help("Filter results of the selected model)");
|
||||||
program.add_argument("-s", "--score").default_value("any").help("Filter results of the score name supplied");
|
program.add_argument("-s", "--score").default_value("any").help("Filter results of the score name supplied");
|
||||||
|
Loading…
Reference in New Issue
Block a user