Add --no-train-score to b_main

This commit is contained in:
2024-01-30 13:11:16 +01:00
parent 7dbef9fc36
commit e336d39cfb
3 changed files with 12 additions and 8 deletions

View File

@@ -31,6 +31,7 @@ void manageArguments(argparse::ArgumentParser& program)
);
program.add_argument("--title").default_value("").help("Experiment title");
program.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
program.add_argument("--no-train-score").help("Don't compute train score").default_value(false).implicit_value(true);
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
program.add_argument("--save").help("Save result (always save if no dataset is supplied)").default_value(false).implicit_value(true);
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
@@ -58,7 +59,7 @@ int main(int argc, char** argv)
manageArguments(program);
std::string file_name, model_name, title, hyperparameters_file;
json hyperparameters_json;
bool discretize_dataset, stratified, saveResults, quiet;
bool discretize_dataset, stratified, saveResults, quiet, no_train_score;
std::vector<int> seeds;
std::vector<std::string> filesToTest;
int n_folds;
@@ -74,6 +75,7 @@ int main(int argc, char** argv)
auto hyperparameters = program.get<std::string>("hyperparameters");
hyperparameters_json = json::parse(hyperparameters);
hyperparameters_file = program.get<std::string>("hyper-file");
no_train_score = program.get<bool>("no-train-score");
if (hyperparameters_file != "" && hyperparameters != "{}") {
throw runtime_error("hyperparameters and hyper_file are mutually exclusive");
}
@@ -123,7 +125,7 @@ int main(int argc, char** argv)
}
platform::Timer timer;
timer.start();
experiment.go(filesToTest, quiet);
experiment.go(filesToTest, quiet, no_train_score);
experiment.setDuration(timer.getDuration());
if (saveResults) {
experiment.save(platform::Paths::results());

View File

@@ -101,7 +101,7 @@ namespace platform {
std::cout << data.dump(4) << std::endl;
}
void Experiment::go(std::vector<std::string> filesToProcess, bool quiet)
void Experiment::go(std::vector<std::string> filesToProcess, bool quiet, bool no_train_score)
{
for (auto fileName : filesToProcess) {
if (fileName.size() > max_name)
@@ -122,7 +122,7 @@ namespace platform {
for (auto fileName : filesToProcess) {
if (!quiet)
std::cout << " " << setw(3) << right << num++ << " " << setw(max_name) << left << fileName << right << flush;
cross_validation(fileName, quiet);
cross_validation(fileName, quiet, no_train_score);
if (!quiet)
std::cout << std::endl;
}
@@ -150,7 +150,7 @@ namespace platform {
std::cout << prefix << color << fold << Colors::RESET() << "(" << color << phase << Colors::RESET() << ")" << flush;
}
void Experiment::cross_validation(const std::string& fileName, bool quiet)
void Experiment::cross_validation(const std::string& fileName, bool quiet, bool no_train_score)
{
auto datasets = Datasets(discretized, Paths::datasets());
// Get dataset
@@ -218,8 +218,10 @@ namespace platform {
edges[item] = clf->getNumberOfEdges();
num_states[item] = clf->getNumberOfStates();
train_time[item] = train_timer.getDuration();
double accuracy_train_value = 0.0;
// Score train
auto accuracy_train_value = clf->score(X_train, y_train);
if (!no_train_score)
accuracy_train_value = clf->score(X_train, y_train);
// Test model
if (!quiet)
showProgress(nfold + 1, getColor(clf->getStatus()), "c");

View File

@@ -85,8 +85,8 @@ namespace platform {
Experiment& setHyperparameters(const HyperParameters& hyperparameters_) { this->hyperparameters = hyperparameters_; return *this; }
std::string get_file_name();
void save(const std::string& path);
void cross_validation(const std::string& fileName, bool quiet);
void go(std::vector<std::string> filesToProcess, bool quiet);
void cross_validation(const std::string& fileName, bool quiet, bool no_train_score);
void go(std::vector<std::string> filesToProcess, bool quiet, bool no_train_score);
void show();
void report();
private: