From 26b649ebae1a7729e114e094d42a2ee9adf9273d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Sun, 22 Oct 2023 20:03:34 +0200 Subject: [PATCH] Refactor ManageResults and CommandParser --- README.md | 2 + src/Platform/CMakeLists.txt | 4 +- src/Platform/CommandParser.cc | 79 +++++++++++ src/Platform/CommandParser.h | 21 +++ src/Platform/DotEnv.h | 11 -- src/Platform/ManageResults.cc | 212 ++++++++++++++++++++++++++++++ src/Platform/ManageResults.h | 31 +++++ src/Platform/Paths.h | 1 + src/Platform/Results.cc | 238 +++------------------------------- src/Platform/Results.h | 45 +++---- src/Platform/Utils.h | 11 ++ src/Platform/b_manage.cc | 13 +- 12 files changed, 405 insertions(+), 263 deletions(-) create mode 100644 src/Platform/CommandParser.cc create mode 100644 src/Platform/CommandParser.h create mode 100644 src/Platform/ManageResults.cc create mode 100644 src/Platform/ManageResults.h diff --git a/README.md b/README.md index ad2660c..8eb2318 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # BayesNet +[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) + Bayesian Network Classifier with libtorch from scratch ## 0. Setup diff --git a/src/Platform/CMakeLists.txt b/src/Platform/CMakeLists.txt index a253283..5c8c317 100644 --- a/src/Platform/CMakeLists.txt +++ b/src/Platform/CMakeLists.txt @@ -5,10 +5,12 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp) include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include) include_directories(${BayesNet_SOURCE_DIR}/lib/json/include) include_directories(${BayesNet_SOURCE_DIR}/lib/libxlsxwriter/include) + add_executable(b_main b_main.cc Folding.cc Experiment.cc Datasets.cc Dataset.cc Models.cc ReportConsole.cc ReportBase.cc) -add_executable(b_manage b_manage.cc Results.cc Result.cc ReportConsole.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc) +add_executable(b_manage b_manage.cc Results.cc ManageResults.cc CommandParser.cc Result.cc ReportConsole.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc) add_executable(b_list b_list.cc Datasets.cc Dataset.cc) add_executable(b_best b_best.cc BestResults.cc Result.cc Statistics.cc BestResultsExcel.cc ExcelFile.cc) + target_link_libraries(b_main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}") target_link_libraries(b_manage "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" ArffFiles mdlp) target_link_libraries(b_best Boost::boost "${XLSXWRITER_LIB}") diff --git a/src/Platform/CommandParser.cc b/src/Platform/CommandParser.cc new file mode 100644 index 0000000..5123e66 --- /dev/null +++ b/src/Platform/CommandParser.cc @@ -0,0 +1,79 @@ +#include "CommandParser.h" +#include +#include +#include +#include "Colors.h" +#include "Utils.h" + +namespace platform { + void CommandParser::messageError(const string& message) + { + cout << Colors::RED() << message << Colors::RESET() << endl; + } + pair CommandParser::parse(const string& color, const vector>& options, const char defaultCommand) + { + bool finished = false; + while (!finished) { + stringstream oss; + string line; + oss << color << "Choose option ("; + bool first = true; + for (auto& option : options) { + if (first) { + first = false; + } else { + oss << ", "; + } + oss << get(option) << "=" << get(option); + } + oss << "): "; + cout << oss.str(); + getline(cin, line); + cout << Colors::RESET(); + line = trim(line); + if (line.size() == 0) + continue; + if (all_of(line.begin(), line.end(), ::isdigit)) { + command = defaultCommand; + index = stoi(line); + finished = true; + break; + } + bool found = false; + for (auto& option : options) { + if (line[0] == get(option)) { + found = true; + // it's a match + line.erase(line.begin()); + line = trim(line); + if (get(option)) { + // The option requires a value + if (line.size() == 0) { + messageError("Option " + get(option) + " requires a value"); + break; + } + try { + index = stoi(line); + } + catch (const std::invalid_argument& ia) { + messageError("Invalid value: " + line); + break; + } + } else { + if (line.size() > 0) { + messageError("option " + get(option) + " doesn't accept values"); + break; + } + } + command = get(option); + finished = true; + break; + } + } + if (!found) { + messageError("I don't know " + line); + } + } + return { command, index }; + } +} /* namespace platform */ \ No newline at end of file diff --git a/src/Platform/CommandParser.h b/src/Platform/CommandParser.h new file mode 100644 index 0000000..1e71d34 --- /dev/null +++ b/src/Platform/CommandParser.h @@ -0,0 +1,21 @@ +#ifndef COMMAND_PARSER_H +#define COMMAND_PARSER_H +#include +#include +#include +using namespace std; + +namespace platform { + class CommandParser { + public: + CommandParser() = default; + pair parse(const string& color, const vector>& options, const char defaultCommand); + char getCommand() const { return command; }; + int getIndex() const { return index; }; + private: + void messageError(const string& message); + char command; + int index; + }; +} /* namespace platform */ +#endif /* COMMAND_PARSER_H */ \ No newline at end of file diff --git a/src/Platform/DotEnv.h b/src/Platform/DotEnv.h index 7d5ee2b..8b7a0cf 100644 --- a/src/Platform/DotEnv.h +++ b/src/Platform/DotEnv.h @@ -13,17 +13,6 @@ namespace platform { class DotEnv { private: std::map env; - std::string trim(const std::string& str) - { - std::string result = str; - result.erase(result.begin(), std::find_if(result.begin(), result.end(), [](int ch) { - return !std::isspace(ch); - })); - result.erase(std::find_if(result.rbegin(), result.rend(), [](int ch) { - return !std::isspace(ch); - }).base(), result.end()); - return result; - } public: DotEnv() { diff --git a/src/Platform/ManageResults.cc b/src/Platform/ManageResults.cc new file mode 100644 index 0000000..39652db --- /dev/null +++ b/src/Platform/ManageResults.cc @@ -0,0 +1,212 @@ +#include "ManageResults.h" +#include "CommandParser.h" +#include +#include +#include "Colors.h" +#include "CLocale.h" +#include "Paths.h" +#include "ReportConsole.h" +#include "ReportExcel.h" + +namespace platform { + + ManageResults::ManageResults(int numFiles, const string& model, const string& score, bool complete, bool partial, bool compare) : + numFiles{ numFiles }, complete{ complete }, partial{ partial }, compare{ compare }, results(Results(Paths::results(), model, score, complete, partial, compare)) + { + indexList = true; + openExcel = false; + workbook = NULL; + if (numFiles == 0) { + this->numFiles = results.size(); + } + } + void ManageResults::doMenu() + { + results.sortDate(); + list(); + menu(); + if (openExcel) { + workbook_close(workbook); + } + cout << Colors::RESET() << "Done!" << endl; + } + void ManageResults::list() + { + if (results.empty()) { + cout << Colors::MAGENTA() << "No results found!" << Colors::RESET() << endl; + exit(0); + } + auto temp = ConfigLocale(); + cout << Colors::GREEN() << "Results found: " << numFiles << endl; + cout << "-------------------" << endl; + if (complete) { + cout << Colors::MAGENTA() << "Only listing complete results" << endl; + } + if (partial) { + cout << Colors::MAGENTA() << "Only listing partial results" << endl; + } + auto i = 0; + cout << Colors::GREEN() << " # Date Model Score Name Score C/P Duration Title" << endl; + cout << "=== ========== ============ =========== =========== === ========= =============================================================" << endl; + bool odd = true; + for (auto& result : results) { + auto color = odd ? Colors::BLUE() : Colors::CYAN(); + cout << color << setw(3) << fixed << right << i++ << " "; + cout << result.to_string() << endl; + if (i == numFiles) { + break; + } + odd = !odd; + } + } + bool ManageResults::confirmAction(const string& intent, const string& fileName) const + { + string color; + if (intent == "delete") { + color = Colors::RED(); + } else { + color = Colors::YELLOW(); + } + string line; + bool finished = false; + while (!finished) { + cout << color << "Really want to " << intent << " " << fileName << "? (y/n): "; + getline(cin, line); + finished = line.size() == 1 && (tolower(line[0]) == 'y' || tolower(line[0] == 'n')); + } + if (tolower(line[0]) == 'y') { + return true; + } + cout << "Not done!" << endl; + return false; + } + void ManageResults::report(const int index, const bool excelReport) + { + cout << Colors::YELLOW() << "Reporting " << results.at(index).getFilename() << endl; + auto data = results.at(index).load(); + if (excelReport) { + ReportExcel reporter(data, compare, workbook); + reporter.show(); + openExcel = true; + workbook = reporter.getWorkbook(); + } else { + ReportConsole reporter(data, compare); + reporter.show(); + } + } + void ManageResults::showIndex(const int index, const int idx) + { + // Show a dataset result inside a report + auto data = results.at(index).load(); + if (idx < 0 or idx >= static_cast(data["results"].size())) { + cout << "Invalid index" << endl; + return; + } + cout << Colors::YELLOW() << "Showing " << results.at(index).getFilename() << endl; + ReportConsole reporter(data, compare, idx); + reporter.show(); + } + void ManageResults::sortList() + { + cout << Colors::YELLOW() << "Choose sorting field (date='d', score='s', duration='u', model='m'): "; + string line; + char option; + getline(cin, line); + if (line.size() == 0) + return; + if (line.size() > 1) { + cout << "Invalid option" << endl; + return; + } + option = line[0]; + switch (option) { + case 'd': + results.sortDate(); + break; + case 's': + results.sortScore(); + break; + case 'u': + results.sortDuration(); + break; + case 'm': + results.sortModel(); + break; + default: + cout << "Invalid option" << endl; + } + } + void ManageResults::menu() + { + char option; + int index, subIndex; + bool finished = false; + string filename; + // tuple + vector> mainOptions = { + {"quit", 'q', false}, + {"list", 'l', false}, + {"delete", 'd', true}, + {"hide", 'h', true}, + {"sort", 's', false}, + {"report", 'r', true}, + {"excel", 'e', true} + }; + vector> listOptions = { + {"report", 'r', true}, + {"list", 'l', false}, + {"quit", 'q', false} + }; + auto parser = CommandParser(); + while (!finished) { + if (indexList) { + tie(option, index) = parser.parse(Colors::GREEN(), mainOptions, 'r'); + } else { + tie(option, subIndex) = parser.parse(Colors::MAGENTA(), listOptions, 'r'); + } + switch (option) { + case 'q': + finished = true; + break; + case 'l': + list(); + indexList = true; + break; + case 'd': + filename = results.at(index).getFilename(); + if (!confirmAction("delete", filename)) + break; + cout << "Deleting " << filename << endl; + results.deleteResult(index); + cout << "File: " + filename + " deleted!" << endl; + list(); + break; + case 'h': + filename = results.at(index).getFilename(); + if (!confirmAction("hide", filename)) + break; + filename = results.at(index).getFilename(); + cout << "Hiding " << filename << endl; + results.hideResult(index, Paths::hiddenResults()); + cout << "File: " + filename + " hidden! (moved to " << Paths::hiddenResults() << ")" << endl; + list(); + break; + case 's': + sortList(); + list(); + break; + case 'r': + if (indexList) { + report(index, false); + indexList = false; + } else { + showIndex(index, subIndex); + } + break; + case 'e': + report(index, true); + break; + } + } + } +} /* namespace platform */ diff --git a/src/Platform/ManageResults.h b/src/Platform/ManageResults.h new file mode 100644 index 0000000..3766970 --- /dev/null +++ b/src/Platform/ManageResults.h @@ -0,0 +1,31 @@ +#ifndef MANAGE_RESULTS_H +#define MANAGE_RESULTS_H +#include "Results.h" +#include "xlsxwriter.h" + +namespace platform { + class ManageResults { + public: + ManageResults(int numFiles, const string& model, const string& score, bool complete, bool partial, bool compare); + ~ManageResults() = default; + void doMenu(); + private: + void list(); + bool confirmAction(const string& intent, const string& fileName) const; + void report(const int index, const bool excelReport); + void showIndex(const int index, const int idx); + void sortList(); + void menu(); + int numFiles; + bool indexList; + bool openExcel; + bool complete; + bool partial; + bool compare; + Results results; + lxw_workbook* workbook; + }; + +} + +#endif /* MANAGE_RESULTS_H */ \ No newline at end of file diff --git a/src/Platform/Paths.h b/src/Platform/Paths.h index 16d459c..c27dd3d 100644 --- a/src/Platform/Paths.h +++ b/src/Platform/Paths.h @@ -6,6 +6,7 @@ namespace platform { class Paths { public: static std::string results() { return "results/"; } + static std::string hiddenResults() { return "hidden_results/"; } static std::string excel() { return "excel/"; } static std::string cfs() { return "cfs/"; } static std::string datasets() diff --git a/src/Platform/Results.cc b/src/Platform/Results.cc index 1f0cd66..6f73a4c 100644 --- a/src/Platform/Results.cc +++ b/src/Platform/Results.cc @@ -1,11 +1,14 @@ -#include #include "Results.h" -#include "ReportConsole.h" -#include "ReportExcel.h" +#include #include "BestScore.h" -#include "Colors.h" -#include "CLocale.h" + namespace platform { + Results::Results(const string& path, const string& model, const string& score, bool complete, bool partial, bool compare) : + path(path), model(model), scoreName(score), complete(complete), partial(partial), compare(compare) + { + load(); + maxModel = (*max_element(files.begin(), files.end(), [](const Result& a, const Result& b) { return a.getModel().size() < b.getModel().size(); })).getModel().size(); + }; void Results::load() { using std::filesystem::directory_iterator; @@ -20,212 +23,22 @@ namespace platform { files.push_back(result); } } - if (max == 0) { - max = files.size(); - } } - void Results::show() const + void Results::hideResult(int index, const string& pathHidden) { - auto temp = ConfigLocale(); - cout << Colors::GREEN() << "Results found: " << files.size() << endl; - cout << "-------------------" << endl; - if (complete) { - cout << Colors::MAGENTA() << "Only listing complete results" << endl; - } - if (partial) { - cout << Colors::MAGENTA() << "Only listing partial results" << endl; - } - auto i = 0; - cout << Colors::GREEN() << " # Date Model Score Name Score C/P Duration Title" << endl; - cout << "=== ========== ============ =========== =========== === ========= =============================================================" << endl; - bool odd = true; - for (const auto& result : files) { - auto color = odd ? Colors::BLUE() : Colors::CYAN(); - cout << color << setw(3) << fixed << right << i++ << " "; - cout << result.to_string() << endl; - if (i == max && max != 0) { - break; - } - odd = !odd; - } + auto filename = files.at(index).getFilename(); + rename((path + "/" + filename).c_str(), (pathHidden + "/" + filename).c_str()); + files.erase(files.begin() + index); } - int Results::getIndex(const string& intent) const + void Results::deleteResult(int index) { - string color; - if (intent == "delete") { - color = Colors::RED(); - } else { - color = Colors::YELLOW(); - } - cout << color << "Choose result to " << intent << " (cancel=-1): "; - string line; - getline(cin, line); - int index = stoi(line); - if (index >= -1 && index < static_cast(files.size())) { - return index; - } - cout << "Invalid index" << endl; - return -1; + auto filename = files.at(index).getFilename(); + remove((path + "/" + filename).c_str()); + files.erase(files.begin() + index); } - void Results::report(const int index, const bool excelReport) + int Results::size() const { - cout << Colors::YELLOW() << "Reporting " << files.at(index).getFilename() << endl; - auto data = files.at(index).load(); - if (excelReport) { - ReportExcel reporter(data, compare, workbook); - reporter.show(); - openExcel = true; - workbook = reporter.getWorkbook(); - } else { - ReportConsole reporter(data, compare); - reporter.show(); - } - } - void Results::showIndex(const int index, const int idx) const - { - auto data = files.at(index).load(); - if (idx < 0 or idx >= static_cast(data["results"].size())) { - cout << "Invalid index" << endl; - return; - } - cout << Colors::YELLOW() << "Showing " << files.at(index).getFilename() << endl; - ReportConsole reporter(data, compare, idx); - reporter.show(); - } - void Results::menu() - { - char option; - int index; - bool finished = false; - string color, context; - string filename, line, options = "qldhsre"; - while (!finished) { - if (indexList) { - color = Colors::GREEN(); - context = " (quit='q', list='l', delete='d', hide='h', sort='s', report='r', excel='e'): "; - options = "qldhsre"; - } else { - color = Colors::MAGENTA(); - context = " (quit='q', list='l'): "; - options = "ql"; - } - cout << Colors::RESET() << color; - - cout << "Choose option " << context; - getline(cin, line); - if (line.size() == 0) - continue; - if (options.find(line[0]) != string::npos) { - if (line.size() > 1) { - cout << "Invalid option" << endl; - continue; - } - option = line[0]; - } else { - if (all_of(line.begin(), line.end(), ::isdigit)) { - int idx = stoi(line); - if (indexList) { - // The value is about the files list - index = idx; - if (index >= 0 && index < max) { - report(index, false); - indexList = false; - continue; - } - } else { - // The value is about the result showed on screen - showIndex(index, idx); - continue; - } - } - cout << "Invalid option" << endl; - continue; - } - switch (option) { - case 'q': - finished = true; - break; - case 'l': - show(); - indexList = true; - break; - case 'd': - index = getIndex("delete"); - if (index == -1) - break; - filename = files[index].getFilename(); - cout << "Deleting " << filename << endl; - remove((path + "/" + filename).c_str()); - files.erase(files.begin() + index); - cout << "File: " + filename + " deleted!" << endl; - show(); - indexList = true; - break; - case 'h': - index = getIndex("hide"); - if (index == -1) - break; - filename = files[index].getFilename(); - cout << "Hiding " << filename << endl; - rename((path + "/" + filename).c_str(), (path + "/." + filename).c_str()); - files.erase(files.begin() + index); - show(); - menu(); - indexList = true; - break; - case 's': - sortList(); - indexList = true; - show(); - break; - case 'r': - index = getIndex("report"); - if (index == -1) - break; - indexList = false; - report(index, false); - break; - case 'e': - index = getIndex("excel"); - if (index == -1) - break; - indexList = true; - report(index, true); - break; - default: - cout << "Invalid option" << endl; - } - } - } - void Results::sortList() - { - cout << Colors::YELLOW() << "Choose sorting field (date='d', score='s', duration='u', model='m'): "; - string line; - char option; - getline(cin, line); - if (line.size() == 0) - return; - if (line.size() > 1) { - cout << "Invalid option" << endl; - return; - } - option = line[0]; - switch (option) { - case 'd': - sortDate(); - break; - case 's': - sortScore(); - break; - case 'u': - sortDuration(); - break; - case 'm': - sortModel(); - break; - default: - cout << "Invalid option" << endl; - } + return files.size(); } void Results::sortDate() { @@ -251,19 +64,8 @@ namespace platform { return a.getScore() > b.getScore(); }); } - void Results::manage() + bool Results::empty() const { - if (files.size() == 0) { - cout << "No results found!" << endl; - exit(0); - } - sortDate(); - show(); - menu(); - if (openExcel) { - workbook_close(workbook); - } - cout << Colors::RESET() << "Done!" << endl; + return files.empty(); } - } \ No newline at end of file diff --git a/src/Platform/Results.h b/src/Platform/Results.h index b322cfb..c6d207c 100644 --- a/src/Platform/Results.h +++ b/src/Platform/Results.h @@ -1,6 +1,5 @@ #ifndef RESULTS_H #define RESULTS_H -#include "xlsxwriter.h" #include #include #include @@ -12,35 +11,29 @@ namespace platform { class Results { public: - Results(const string& path, const int max, const string& model, const string& score, bool complete, bool partial, bool compare) : - path(path), max(max), model(model), scoreName(score), complete(complete), partial(partial), compare(compare) - { - load(); - }; - void manage(); - private: - string path; - int max; - string model; - string scoreName; - bool complete; - bool partial; - bool indexList = true; - bool openExcel = false; - bool compare; - lxw_workbook* workbook = NULL; - vector files; - void load(); // Loads the list of results - void show() const; - void report(const int index, const bool excelReport); - void showIndex(const int index, const int idx) const; - int getIndex(const string& intent) const; - void menu(); - void sortList(); + Results(const string& path, const string& model, const string& score, bool complete, bool partial, bool compare); void sortDate(); void sortScore(); void sortModel(); void sortDuration(); + int maxModelSize() const { return maxModel; }; + void hideResult(int index, const string& path); + void deleteResult(int index); + int size() const; + bool empty() const; + vector::iterator begin() { return files.begin(); }; + vector::iterator end() { return files.end(); }; + Result& at(int index) { return files.at(index); }; + private: + string path; + string model; + string scoreName; + bool complete; + bool partial; + bool compare; + int maxModel; + vector files; + void load(); // Loads the list of results }; }; diff --git a/src/Platform/Utils.h b/src/Platform/Utils.h index 3e24f05..6b6f599 100644 --- a/src/Platform/Utils.h +++ b/src/Platform/Utils.h @@ -15,5 +15,16 @@ namespace platform { } return result; } + static std::string trim(const std::string& str) + { + std::string result = str; + result.erase(result.begin(), std::find_if(result.begin(), result.end(), [](int ch) { + return !std::isspace(ch); + })); + result.erase(std::find_if(result.rbegin(), result.rend(), [](int ch) { + return !std::isspace(ch); + }).base(), result.end()); + return result; + } } #endif \ No newline at end of file diff --git a/src/Platform/b_manage.cc b/src/Platform/b_manage.cc index 7e95473..ef62868 100644 --- a/src/Platform/b_manage.cc +++ b/src/Platform/b_manage.cc @@ -1,7 +1,6 @@ #include #include -#include "Paths.h" -#include "Results.h" +#include "ManageResults.h" using namespace std; @@ -37,15 +36,15 @@ argparse::ArgumentParser manageArguments(int argc, char** argv) int main(int argc, char** argv) { auto program = manageArguments(argc, argv); - auto number = program.get("number"); - auto model = program.get("model"); - auto score = program.get("score"); + int number = program.get("number"); + string model = program.get("model"); + string score = program.get("score"); auto complete = program.get("complete"); auto partial = program.get("partial"); auto compare = program.get("compare"); if (complete) partial = false; - auto results = platform::Results(platform::Paths::results(), number, model, score, complete, partial, compare); - results.manage(); + auto manager = platform::ManageResults(number, model, score, complete, partial, compare); + manager.doMenu(); return 0; }