diff --git a/.gitmodules b/.gitmodules index aecd798..83c2fb7 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "lib/catch2"] path = lib/catch2 url = https://github.com/catchorg/Catch2.git +[submodule "lib/argparse"] + path = lib/argparse + url = https://github.com/p-ranav/argparse diff --git a/CMakeLists.txt b/CMakeLists.txt index d236e8a..c285764 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -42,6 +42,7 @@ include(CodeCoverage) # include(FetchContent) add_git_submodule("lib/mdlp") add_git_submodule("lib/catch2") +add_git_submodule("lib/argparse") # Subdirectories # -------------- diff --git a/lib/argparse b/lib/argparse new file mode 160000 index 0000000..b0930ab --- /dev/null +++ b/lib/argparse @@ -0,0 +1 @@ +Subproject commit b0930ab0288185815d6dc67af59de7014a6272f7 diff --git a/sample/CMakeLists.txt b/sample/CMakeLists.txt index 41ea8c0..0953c7a 100644 --- a/sample/CMakeLists.txt +++ b/sample/CMakeLists.txt @@ -2,5 +2,6 @@ include_directories(${BayesNet_SOURCE_DIR}/src/Platform) include_directories(${BayesNet_SOURCE_DIR}/src/BayesNet) include_directories(${BayesNet_SOURCE_DIR}/lib/Files) include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp) +include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include) add_executable(BayesNetSample sample.cc) target_link_libraries(BayesNetSample BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}") \ No newline at end of file diff --git a/sample/sample.cc b/sample/sample.cc index 4a19b10..d2191f7 100644 --- a/sample/sample.cc +++ b/sample/sample.cc @@ -2,7 +2,7 @@ #include #include #include -#include +#include #include "ArffFiles.h" #include "Network.h" #include "BayesMetrics.h" @@ -15,68 +15,7 @@ using namespace std; -const string PATH = "data/"; - -/* print a description of all supported options */ -void usage(const char* path) -{ - /* take only the last portion of the path */ - const char* basename = strrchr(path, '/'); - basename = basename ? basename + 1 : path; - - cout << "usage: " << basename << "[OPTION]" << endl; - cout << " -h, --help\t\t Print this help and exit." << endl; - cout - << " -f, --file[=FILENAME]\t {diabetes, glass, iris, kdd_JapaneseVowels, letter, liver-disorders, mfeat-factors}." - << endl; - cout << " -p, --path[=FILENAME]\t folder where the data files are located, default " << PATH << endl; - cout << " -m, --model={AODE, KDB, SPODE, TAN}\t " << endl; -} - -tuple parse_arguments(int argc, char** argv) -{ - string file_name; - string model_name; - string path = PATH; - const vector long_options = { - {"help", no_argument, nullptr, 'h'}, - {"file", required_argument, nullptr, 'f'}, - {"path", required_argument, nullptr, 'p'}, - {"model", required_argument, nullptr, 'm'}, - {nullptr, no_argument, nullptr, 0} - }; - while (true) { - const auto c = getopt_long(argc, argv, "hf:p:m:", long_options.data(), nullptr); - if (c == -1) - break; - switch (c) { - case 'h': - usage(argv[0]); - exit(0); - case 'f': - file_name = string(optarg); - break; - case 'm': - model_name = string(optarg); - break; - case 'p': - path = optarg; - if (path.back() != '/') - path += '/'; - break; - case '?': - usage(argv[0]); - exit(1); - default: - abort(); - } - } - if (file_name.empty()) { - usage(argv[0]); - exit(1); - } - return make_tuple(file_name, path, model_name); -} +const string PATH = "../../data/"; inline constexpr auto hash_conv(const std::string_view sv) { @@ -117,7 +56,7 @@ bool file_exists(const std::string& name) } } -tuple get_options(int argc, char** argv) +int main(int argc, char** argv) { map datasets = { {"diabetes", true}, @@ -129,35 +68,60 @@ tuple get_options(int argc, char** argv) {"liver-disorders", true}, {"mfeat-factors", true}, }; - vector models = { "AODE", "KDB", "SPODE", "TAN" }; - string file_name; - string path; - string model_name; - tie(file_name, path, model_name) = parse_arguments(argc, argv); - if (datasets.find(file_name) == datasets.end()) { - cout << "Invalid file name: " << file_name << endl; - usage(argv[0]); + auto valid_datasets = vector(); + for (auto dataset : datasets) { + valid_datasets.push_back(dataset.first); + } + argparse::ArgumentParser program("BayesNetSample"); + program.add_argument("-f", "--file") + .help("Dataset file name") + .action([valid_datasets](const std::string& value) { + if (find(valid_datasets.begin(), valid_datasets.end(), value) != valid_datasets.end()) { + return value; + } + throw runtime_error("file must be one of {diabetes, ecoli, glass, iris, kdd_JapaneseVowels, letter, liver-disorders, mfeat-factors}"); + } + ); + program.add_argument("-p", "--path") + .help(" folder where the data files are located, default") + .default_value(string{ PATH } + ); + program.add_argument("-m", "--model") + .help("Model to use {AODE, KDB, SPODE, TAN}") + .action([](const std::string& value) { + static const vector choices = { "AODE", "KDB", "SPODE", "TAN" }; + if (find(choices.begin(), choices.end(), value) != choices.end()) { + return value; + } + throw runtime_error("Model must be one of {AODE, KDB, SPODE, TAN}"); + } + ); + program.add_argument("--discretize").default_value(false).implicit_value(true); + bool class_last, discretize_dataset; + string model_name, file_name, path, complete_file_name; + try { + program.parse_args(argc, argv); + file_name = program.get("file"); + path = program.get("path"); + model_name = program.get("model"); + discretize_dataset = program.get("discretize"); + complete_file_name = path + file_name + ".arff"; + class_last = datasets[file_name]; + if (!file_exists(complete_file_name)) { + throw runtime_error("Data File " + path + file_name + ".arff" + " does not exist"); + } + } + catch (const exception& err) { + cerr << err.what() << endl; + cerr << program; exit(1); } - if (!file_exists(path + file_name + ".arff")) { - cout << "Data File " << path + file_name + ".arff" << " does not exist" << endl; - usage(argv[0]); - exit(1); - } - if (find(models.begin(), models.end(), model_name) == models.end()) { - cout << "Invalid model name: " << model_name << endl; - usage(argv[0]); - exit(1); - } - return { file_name, path, model_name }; -} -int main(int argc, char** argv) -{ - string file_name, path, model_name; - tie(file_name, path, model_name) = get_options(argc, argv); + /* + * Begin Processing + */ auto handler = ArffFiles(); - handler.load(path + file_name + ".arff"); + handler.load(complete_file_name, class_last); // Get Dataset X, y vector& X = handler.getX(); mdlp::labels_t& y = handler.getY();