Add generate-fold-files to b_main

This commit is contained in:
2024-05-28 10:52:08 +02:00
parent b34af13eea
commit f5d5c35002
7 changed files with 58 additions and 10 deletions

View File

@@ -47,6 +47,7 @@ void manageArguments(argparse::ArgumentParser& program)
);
program.add_argument("--title").default_value("").help("Experiment title");
program.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
program.add_argument("--generate-fold-files").help("generate fold information in datasets_experiment folder").default_value(false).implicit_value(true);
program.add_argument("--no-train-score").help("Don't compute train score").default_value(false).implicit_value(true);
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
program.add_argument("--save").help("Save result (always save if no dataset is supplied)").default_value(false).implicit_value(true);
@@ -75,7 +76,7 @@ int main(int argc, char** argv)
manageArguments(program);
std::string file_name, model_name, title, hyperparameters_file, datasets_file;
json hyperparameters_json;
bool discretize_dataset, stratified, saveResults, quiet, no_train_score;
bool discretize_dataset, stratified, saveResults, quiet, no_train_score, generate_fold_files;
std::vector<int> seeds;
std::vector<std::string> file_names;
std::vector<std::string> filesToTest;
@@ -95,6 +96,7 @@ int main(int argc, char** argv)
hyperparameters_json = json::parse(hyperparameters);
hyperparameters_file = program.get<std::string>("hyper-file");
no_train_score = program.get<bool>("no-train-score");
generate_fold_files = program.get<bool>("generate-fold-files");
if (hyperparameters_file != "" && hyperparameters != "{}") {
throw runtime_error("hyperparameters and hyper_file are mutually exclusive");
}
@@ -184,7 +186,7 @@ int main(int argc, char** argv)
}
platform::Timer timer;
timer.start();
experiment.go(filesToTest, quiet, no_train_score);
experiment.go(filesToTest, quiet, no_train_score, generate_fold_files);
experiment.setDuration(timer.getDuration());
if (saveResults) {
experiment.saveResult();

View File

@@ -15,6 +15,12 @@ namespace platform {
auto env = platform::DotEnv();
return env.get("source_data");
}
// Builds the relative path of the JSON file that stores one fold of an experiment.
// The name encodes the dataset name, whether the data was discretized, whether the
// split was stratified, the random seed and the fold number, e.g.
// "datasets_experiment/iris_disc_nstrat_271_3.json".
static std::string experiment_file(const std::string& fileName, bool discretize, bool stratified, int seed, int nfold)
{
    std::string path = "datasets_experiment/";
    path += fileName;
    // Tags describing how the dataset was preprocessed and split
    path += discretize ? "_disc_" : "_ndisc_";
    path += stratified ? "strat_" : "nstrat_";
    // Seed and fold number make the file unique within an experiment
    path += std::to_string(seed);
    path += "_";
    path += std::to_string(nfold);
    path += ".json";
    return path;
}
static void createPath(const std::string& path)
{
// Create directory if it does not exist

View File

@@ -23,7 +23,7 @@ namespace platform {
{
std::cout << result.getJson().dump(4) << std::endl;
}
void Experiment::go(std::vector<std::string> filesToProcess, bool quiet, bool no_train_score)
void Experiment::go(std::vector<std::string> filesToProcess, bool quiet, bool no_train_score, bool generate_fold_files)
{
for (auto fileName : filesToProcess) {
if (fileName.size() > max_name)
@@ -47,7 +47,7 @@ namespace platform {
for (auto fileName : filesToProcess) {
if (!quiet)
std::cout << " " << setw(3) << right << num++ << " " << setw(max_name) << left << fileName << right << flush;
cross_validation(fileName, quiet, no_train_score);
cross_validation(fileName, quiet, no_train_score, generate_fold_files);
if (!quiet)
std::cout << std::endl;
}
@@ -74,7 +74,45 @@ namespace platform {
std::cout << prefix << color << fold << Colors::RESET() << "(" << color << phase << Colors::RESET() << ")" << flush;
}
void Experiment::cross_validation(const std::string& fileName, bool quiet, bool no_train_score)
void generate_files(const std::string& fileName, bool discretize, bool stratified, int seed, int nfold, torch::Tensor X_train, torch::Tensor y_train, torch::Tensor X_test, torch::Tensor y_test, std::vector<int>& train, std::vector<int>& test)
{
std::string file_name = Paths::experiment_file(fileName, discretize, stratified, seed, nfold);
auto file = std::ofstream(file_name);
json output;
output["seed"] = seed;
output["nfold"] = nfold;
output["X_train"] = json::array();
auto n = X_train.size(1);
for (int i = 0; i < X_train.size(0); i++) {
if (X_train.dtype() == torch::kFloat32) {
auto xvf_ptr = X_train.index({ i }).data_ptr<float>();
auto feature = std::vector<float>(xvf_ptr, xvf_ptr + n);
output["X_train"].push_back(feature);
} else {
auto feature = std::vector<int>(X_train.index({ i }).data_ptr<int>(), X_train.index({ i }).data_ptr<int>() + n);
output["X_train"].push_back(feature);
}
}
output["y_train"] = std::vector<int>(y_train.data_ptr<int>(), y_train.data_ptr<int>() + n);
output["X_test"] = json::array();
n = X_test.size(1);
for (int i = 0; i < X_test.size(0); i++) {
if (X_train.dtype() == torch::kFloat32) {
auto xvf_ptr = X_test.index({ i }).data_ptr<float>();
auto feature = std::vector<float>(xvf_ptr, xvf_ptr + n);
output["X_test"].push_back(feature);
} else {
auto feature = std::vector<int>(X_test.index({ i }).data_ptr<int>(), X_test.index({ i }).data_ptr<int>() + n);
output["X_test"].push_back(feature);
}
}
output["y_test"] = std::vector<int>(y_test.data_ptr<int>(), y_test.data_ptr<int>() + n);
output["train"] = train;
output["test"] = test;
file << output.dump(4);
file.close();
}
void Experiment::cross_validation(const std::string& fileName, bool quiet, bool no_train_score, bool generate_fold_files)
{
auto datasets = Datasets(discretized, Paths::datasets());
// Get dataset
@@ -137,6 +175,8 @@ namespace platform {
auto y_train = y.index({ train_t });
auto X_test = X.index({ "...", test_t });
auto y_test = y.index({ test_t });
if (generate_fold_files)
generate_files(fileName, discretized, stratified, seed, nfold, X_train, y_train, X_test, y_test, train, test);
if (!quiet)
showProgress(nfold + 1, getColor(clf->getStatus()), "a");
// Train model

View File

@@ -28,8 +28,8 @@ namespace platform {
Experiment& addRandomSeed(int randomSeed) { randomSeeds.push_back(randomSeed); result.addSeed(randomSeed); return *this; }
Experiment& setDuration(float duration) { this->result.setDuration(duration); return *this; }
Experiment& setHyperparameters(const HyperParameters& hyperparameters_) { this->hyperparameters = hyperparameters_; return *this; }
void cross_validation(const std::string& fileName, bool quiet, bool no_train_score);
void go(std::vector<std::string> filesToProcess, bool quiet, bool no_train_score);
void cross_validation(const std::string& fileName, bool quiet, bool no_train_score, bool generate_fold_files);
void go(std::vector<std::string> filesToProcess, bool quiet, bool no_train_score, bool generate_fold_files);
void saveResult();
void show();
void report(bool classification_report = false);

View File

@@ -64,7 +64,7 @@ namespace platform {
// Writes the contents of `data` to a file whose path is built from
// Paths::results() and getFilename(), then closes the file.
// NOTE(review): two declarations of `file` appear below, one joining the path with an
// explicit "/" and one without it; this looks like both sides of a diff hunk - confirm
// which one is live and whether Paths::results() already ends with a separator.
void Result::save()
{
std::ofstream file(Paths::results() + "/" + getFilename());
std::ofstream file(Paths::results() + getFilename());
file << data;
file.close();
}