Add --no-train-score to b_main
This commit is contained in:
@@ -31,6 +31,7 @@ void manageArguments(argparse::ArgumentParser& program)
|
|||||||
);
|
);
|
||||||
program.add_argument("--title").default_value("").help("Experiment title");
|
program.add_argument("--title").default_value("").help("Experiment title");
|
||||||
program.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
|
program.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
|
||||||
|
program.add_argument("--no-train-score").help("Don't compute train score").default_value(false).implicit_value(true);
|
||||||
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
|
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
|
||||||
program.add_argument("--save").help("Save result (always save if no dataset is supplied)").default_value(false).implicit_value(true);
|
program.add_argument("--save").help("Save result (always save if no dataset is supplied)").default_value(false).implicit_value(true);
|
||||||
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
|
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
|
||||||
@@ -58,7 +59,7 @@ int main(int argc, char** argv)
|
|||||||
manageArguments(program);
|
manageArguments(program);
|
||||||
std::string file_name, model_name, title, hyperparameters_file;
|
std::string file_name, model_name, title, hyperparameters_file;
|
||||||
json hyperparameters_json;
|
json hyperparameters_json;
|
||||||
bool discretize_dataset, stratified, saveResults, quiet;
|
bool discretize_dataset, stratified, saveResults, quiet, no_train_score;
|
||||||
std::vector<int> seeds;
|
std::vector<int> seeds;
|
||||||
std::vector<std::string> filesToTest;
|
std::vector<std::string> filesToTest;
|
||||||
int n_folds;
|
int n_folds;
|
||||||
@@ -74,6 +75,7 @@ int main(int argc, char** argv)
|
|||||||
auto hyperparameters = program.get<std::string>("hyperparameters");
|
auto hyperparameters = program.get<std::string>("hyperparameters");
|
||||||
hyperparameters_json = json::parse(hyperparameters);
|
hyperparameters_json = json::parse(hyperparameters);
|
||||||
hyperparameters_file = program.get<std::string>("hyper-file");
|
hyperparameters_file = program.get<std::string>("hyper-file");
|
||||||
|
no_train_score = program.get<bool>("no-train-score");
|
||||||
if (hyperparameters_file != "" && hyperparameters != "{}") {
|
if (hyperparameters_file != "" && hyperparameters != "{}") {
|
||||||
throw runtime_error("hyperparameters and hyper_file are mutually exclusive");
|
throw runtime_error("hyperparameters and hyper_file are mutually exclusive");
|
||||||
}
|
}
|
||||||
@@ -123,7 +125,7 @@ int main(int argc, char** argv)
|
|||||||
}
|
}
|
||||||
platform::Timer timer;
|
platform::Timer timer;
|
||||||
timer.start();
|
timer.start();
|
||||||
experiment.go(filesToTest, quiet);
|
experiment.go(filesToTest, quiet, no_train_score);
|
||||||
experiment.setDuration(timer.getDuration());
|
experiment.setDuration(timer.getDuration());
|
||||||
if (saveResults) {
|
if (saveResults) {
|
||||||
experiment.save(platform::Paths::results());
|
experiment.save(platform::Paths::results());
|
||||||
|
@@ -101,7 +101,7 @@ namespace platform {
|
|||||||
std::cout << data.dump(4) << std::endl;
|
std::cout << data.dump(4) << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Experiment::go(std::vector<std::string> filesToProcess, bool quiet)
|
void Experiment::go(std::vector<std::string> filesToProcess, bool quiet, bool no_train_score)
|
||||||
{
|
{
|
||||||
for (auto fileName : filesToProcess) {
|
for (auto fileName : filesToProcess) {
|
||||||
if (fileName.size() > max_name)
|
if (fileName.size() > max_name)
|
||||||
@@ -122,7 +122,7 @@ namespace platform {
|
|||||||
for (auto fileName : filesToProcess) {
|
for (auto fileName : filesToProcess) {
|
||||||
if (!quiet)
|
if (!quiet)
|
||||||
std::cout << " " << setw(3) << right << num++ << " " << setw(max_name) << left << fileName << right << flush;
|
std::cout << " " << setw(3) << right << num++ << " " << setw(max_name) << left << fileName << right << flush;
|
||||||
cross_validation(fileName, quiet);
|
cross_validation(fileName, quiet, no_train_score);
|
||||||
if (!quiet)
|
if (!quiet)
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
@@ -150,7 +150,7 @@ namespace platform {
|
|||||||
std::cout << prefix << color << fold << Colors::RESET() << "(" << color << phase << Colors::RESET() << ")" << flush;
|
std::cout << prefix << color << fold << Colors::RESET() << "(" << color << phase << Colors::RESET() << ")" << flush;
|
||||||
|
|
||||||
}
|
}
|
||||||
void Experiment::cross_validation(const std::string& fileName, bool quiet)
|
void Experiment::cross_validation(const std::string& fileName, bool quiet, bool no_train_score)
|
||||||
{
|
{
|
||||||
auto datasets = Datasets(discretized, Paths::datasets());
|
auto datasets = Datasets(discretized, Paths::datasets());
|
||||||
// Get dataset
|
// Get dataset
|
||||||
@@ -218,8 +218,10 @@ namespace platform {
|
|||||||
edges[item] = clf->getNumberOfEdges();
|
edges[item] = clf->getNumberOfEdges();
|
||||||
num_states[item] = clf->getNumberOfStates();
|
num_states[item] = clf->getNumberOfStates();
|
||||||
train_time[item] = train_timer.getDuration();
|
train_time[item] = train_timer.getDuration();
|
||||||
|
double accuracy_train_value = 0.0;
|
||||||
// Score train
|
// Score train
|
||||||
auto accuracy_train_value = clf->score(X_train, y_train);
|
if (!no_train_score)
|
||||||
|
accuracy_train_value = clf->score(X_train, y_train);
|
||||||
// Test model
|
// Test model
|
||||||
if (!quiet)
|
if (!quiet)
|
||||||
showProgress(nfold + 1, getColor(clf->getStatus()), "c");
|
showProgress(nfold + 1, getColor(clf->getStatus()), "c");
|
||||||
|
@@ -85,8 +85,8 @@ namespace platform {
|
|||||||
Experiment& setHyperparameters(const HyperParameters& hyperparameters_) { this->hyperparameters = hyperparameters_; return *this; }
|
Experiment& setHyperparameters(const HyperParameters& hyperparameters_) { this->hyperparameters = hyperparameters_; return *this; }
|
||||||
std::string get_file_name();
|
std::string get_file_name();
|
||||||
void save(const std::string& path);
|
void save(const std::string& path);
|
||||||
void cross_validation(const std::string& fileName, bool quiet);
|
void cross_validation(const std::string& fileName, bool quiet, bool no_train_score);
|
||||||
void go(std::vector<std::string> filesToProcess, bool quiet);
|
void go(std::vector<std::string> filesToProcess, bool quiet, bool no_train_score);
|
||||||
void show();
|
void show();
|
||||||
void report();
|
void report();
|
||||||
private:
|
private:
|
||||||
|
Reference in New Issue
Block a user