Add --no-train-score to b_main
This commit is contained in:
@@ -31,6 +31,7 @@ void manageArguments(argparse::ArgumentParser& program)
|
||||
);
|
||||
program.add_argument("--title").default_value("").help("Experiment title");
|
||||
program.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
|
||||
program.add_argument("--no-train-score").help("Don't compute train score").default_value(false).implicit_value(true);
|
||||
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
|
||||
program.add_argument("--save").help("Save result (always save if no dataset is supplied)").default_value(false).implicit_value(true);
|
||||
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
|
||||
@@ -58,7 +59,7 @@ int main(int argc, char** argv)
|
||||
manageArguments(program);
|
||||
std::string file_name, model_name, title, hyperparameters_file;
|
||||
json hyperparameters_json;
|
||||
bool discretize_dataset, stratified, saveResults, quiet;
|
||||
bool discretize_dataset, stratified, saveResults, quiet, no_train_score;
|
||||
std::vector<int> seeds;
|
||||
std::vector<std::string> filesToTest;
|
||||
int n_folds;
|
||||
@@ -74,6 +75,7 @@ int main(int argc, char** argv)
|
||||
auto hyperparameters = program.get<std::string>("hyperparameters");
|
||||
hyperparameters_json = json::parse(hyperparameters);
|
||||
hyperparameters_file = program.get<std::string>("hyper-file");
|
||||
no_train_score = program.get<bool>("no-train-score");
|
||||
if (hyperparameters_file != "" && hyperparameters != "{}") {
|
||||
throw runtime_error("hyperparameters and hyper_file are mutually exclusive");
|
||||
}
|
||||
@@ -123,7 +125,7 @@ int main(int argc, char** argv)
|
||||
}
|
||||
platform::Timer timer;
|
||||
timer.start();
|
||||
experiment.go(filesToTest, quiet);
|
||||
experiment.go(filesToTest, quiet, no_train_score);
|
||||
experiment.setDuration(timer.getDuration());
|
||||
if (saveResults) {
|
||||
experiment.save(platform::Paths::results());
|
||||
|
@@ -101,7 +101,7 @@ namespace platform {
|
||||
std::cout << data.dump(4) << std::endl;
|
||||
}
|
||||
|
||||
void Experiment::go(std::vector<std::string> filesToProcess, bool quiet)
|
||||
void Experiment::go(std::vector<std::string> filesToProcess, bool quiet, bool no_train_score)
|
||||
{
|
||||
for (auto fileName : filesToProcess) {
|
||||
if (fileName.size() > max_name)
|
||||
@@ -122,7 +122,7 @@ namespace platform {
|
||||
for (auto fileName : filesToProcess) {
|
||||
if (!quiet)
|
||||
std::cout << " " << setw(3) << right << num++ << " " << setw(max_name) << left << fileName << right << flush;
|
||||
cross_validation(fileName, quiet);
|
||||
cross_validation(fileName, quiet, no_train_score);
|
||||
if (!quiet)
|
||||
std::cout << std::endl;
|
||||
}
|
||||
@@ -150,7 +150,7 @@ namespace platform {
|
||||
std::cout << prefix << color << fold << Colors::RESET() << "(" << color << phase << Colors::RESET() << ")" << flush;
|
||||
|
||||
}
|
||||
void Experiment::cross_validation(const std::string& fileName, bool quiet)
|
||||
void Experiment::cross_validation(const std::string& fileName, bool quiet, bool no_train_score)
|
||||
{
|
||||
auto datasets = Datasets(discretized, Paths::datasets());
|
||||
// Get dataset
|
||||
@@ -218,8 +218,10 @@ namespace platform {
|
||||
edges[item] = clf->getNumberOfEdges();
|
||||
num_states[item] = clf->getNumberOfStates();
|
||||
train_time[item] = train_timer.getDuration();
|
||||
double accuracy_train_value = 0.0;
|
||||
// Score train
|
||||
auto accuracy_train_value = clf->score(X_train, y_train);
|
||||
if (!no_train_score)
|
||||
accuracy_train_value = clf->score(X_train, y_train);
|
||||
// Test model
|
||||
if (!quiet)
|
||||
showProgress(nfold + 1, getColor(clf->getStatus()), "c");
|
||||
|
@@ -85,8 +85,8 @@ namespace platform {
|
||||
Experiment& setHyperparameters(const HyperParameters& hyperparameters_) { this->hyperparameters = hyperparameters_; return *this; }
|
||||
std::string get_file_name();
|
||||
void save(const std::string& path);
|
||||
void cross_validation(const std::string& fileName, bool quiet);
|
||||
void go(std::vector<std::string> filesToProcess, bool quiet);
|
||||
void cross_validation(const std::string& fileName, bool quiet, bool no_train_score);
|
||||
void go(std::vector<std::string> filesToProcess, bool quiet, bool no_train_score);
|
||||
void show();
|
||||
void report();
|
||||
private:
|
||||
|
Reference in New Issue
Block a user