diff --git a/CMakeLists.txt b/CMakeLists.txt index c53a3a2..a187279 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -55,6 +55,7 @@ endif (ENABLE_CLANG_TIDY) add_git_submodule("lib/mdlp") add_git_submodule("lib/argparse") add_git_submodule("lib/json") +add_git_submodule("lib/openXLSX") # Subdirectories # -------------- diff --git a/lib/openXLSX b/lib/openXLSX new file mode 160000 index 0000000..b80da42 --- /dev/null +++ b/lib/openXLSX @@ -0,0 +1 @@ +Subproject commit b80da42d1454f361c29117095ebe1989437db390 diff --git a/src/BayesNet/BoostAODE.cc b/src/BayesNet/BoostAODE.cc index 80fd20c..a7d5e5c 100644 --- a/src/BayesNet/BoostAODE.cc +++ b/src/BayesNet/BoostAODE.cc @@ -10,6 +10,9 @@ namespace bayesnet { } void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters) { + // Check if hyperparameters are valid + const vector validKeys = { "repeatSparent", "maxModels", "ascending" }; + checkHyperparameters(validKeys, hyperparameters); if (hyperparameters.contains("repeatSparent")) { repeatSparent = hyperparameters["repeatSparent"]; } @@ -74,7 +77,7 @@ namespace bayesnet { // Step 3.4: Store classifier and its accuracy to weigh its future vote models.push_back(std::move(model)); significanceModels.push_back(significance); - exitCondition = n_models == maxModels; + exitCondition = n_models == maxModels && repeatSparent; } if (featuresUsed.size() != features.size()) { cout << "Warning: BoostAODE did not use all the features" << endl; diff --git a/src/BayesNet/Classifier.cc b/src/BayesNet/Classifier.cc index ff25657..db4a63f 100644 --- a/src/BayesNet/Classifier.cc +++ b/src/BayesNet/Classifier.cc @@ -152,4 +152,12 @@ namespace bayesnet { { model.dump_cpt(); } + void Classifier::checkHyperparameters(const vector& validKeys, nlohmann::json& hyperparameters) + { + for (const auto& item : hyperparameters.items()) { + if (find(validKeys.begin(), validKeys.end(), item.key()) == validKeys.end()) { + throw invalid_argument("Hyperparameter " + item.key() + " is not valid"); + } + } + } } \ No newline at end of file diff --git a/src/BayesNet/Classifier.h b/src/BayesNet/Classifier.h index 0c2940b..d27e486 100644 --- a/src/BayesNet/Classifier.h +++ b/src/BayesNet/Classifier.h @@ -24,6 +24,7 @@ namespace bayesnet { void checkFitParameters(); virtual void buildModel(const torch::Tensor& weights) = 0; void trainModel(const torch::Tensor& weights) override; + void checkHyperparameters(const vector& validKeys, nlohmann::json& hyperparameters); public: Classifier(Network model); virtual ~Classifier() = default; diff --git a/src/Platform/CMakeLists.txt b/src/Platform/CMakeLists.txt index 78c6615..ae295c9 100644 --- a/src/Platform/CMakeLists.txt +++ b/src/Platform/CMakeLists.txt @@ -4,9 +4,9 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/Files) include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp) include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include) include_directories(${BayesNet_SOURCE_DIR}/lib/json/include) -add_executable(main main.cc Folding.cc platformUtils.cc Experiment.cc Datasets.cc Models.cc Report.cc) -add_executable(manage manage.cc Results.cc Report.cc) +add_executable(main main.cc Folding.cc platformUtils.cc Experiment.cc Datasets.cc Models.cc ReportConsole.cc ReportBase.cc) +add_executable(manage manage.cc Results.cc ReportConsole.cc ReportExcel.cc ReportBase.cc) add_executable(list list.cc platformUtils Datasets.cc) target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}") -target_link_libraries(manage "${TORCH_LIBRARIES}") +target_link_libraries(manage "${TORCH_LIBRARIES}" OpenXLSX::OpenXLSX) target_link_libraries(list ArffFiles mdlp "${TORCH_LIBRARIES}") \ No newline at end of file diff --git a/src/Platform/Experiment.cc b/src/Platform/Experiment.cc index 88e3125..09674ab 100644 --- a/src/Platform/Experiment.cc +++ b/src/Platform/Experiment.cc @@ -1,7 +1,7 @@ #include "Experiment.h" #include "Datasets.h" #include "Models.h" -#include "Report.h" +#include "ReportConsole.h" namespace platform { using json = nlohmann::json; @@ -91,7 +91,7 @@ namespace platform { void Experiment::report() { json data = build_json(); - Report report(data); + ReportConsole report(data); report.show(); } @@ -181,11 +181,11 @@ namespace platform { item++; } cout << "end. " << flush; - delete fold; } result.setScoreTest(torch::mean(accuracy_test).item()).setScoreTrain(torch::mean(accuracy_train).item()); result.setScoreTestStd(torch::std(accuracy_test).item()).setScoreTrainStd(torch::std(accuracy_train).item()); result.setTrainTime(torch::mean(train_time).item()).setTestTime(torch::mean(test_time).item()); + result.setTestTimeStd(torch::std(test_time).item()).setTrainTimeStd(torch::std(train_time).item()); result.setNodes(torch::mean(nodes).item()).setLeaves(torch::mean(edges).item()).setDepth(torch::mean(num_states).item()); result.setDataset(fileName); addResult(result); diff --git a/src/Platform/Paths.h b/src/Platform/Paths.h index fdda25a..b19b09f 100644 --- a/src/Platform/Paths.h +++ b/src/Platform/Paths.h @@ -6,6 +6,7 @@ namespace platform { public: static std::string datasets() { return "datasets/"; } static std::string results() { return "results/"; } + static std::string excel() { return "excel/"; } }; } #endif \ No newline at end of file diff --git a/src/Platform/Report.h b/src/Platform/Report.h deleted file mode 100644 index 105785f..0000000 --- a/src/Platform/Report.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef REPORT_H -#define REPORT_H -#include -#include -#include -#include "Colors.h" - -using json = nlohmann::json; -const int MAXL = 128; -namespace platform { - using namespace std; - class Report { - public: - explicit Report(json data_) { data = data_; }; - virtual ~Report() = default; - void show(); - private: - void header(); - void body(); - void footer(); - string fromVector(const string& key); - json data; - double totalScore; // Total score of all results in a report - }; -}; -#endif \ No newline at end of file diff --git a/src/Platform/ReportBase.cc b/src/Platform/ReportBase.cc new file mode 100644 index 0000000..24125f8 --- /dev/null +++ b/src/Platform/ReportBase.cc @@ -0,0 +1,37 @@ +#include +#include +#include "ReportBase.h" +#include "BestResult.h" + + +namespace platform { + string ReportBase::fromVector(const string& key) + { + stringstream oss; + string sep = ""; + oss << "["; + for (auto& item : data[key]) { + oss << sep << item.get(); + sep = ", "; + } + oss << "]"; + return oss.str(); + } + string ReportBase::fVector(const string& title, const json& data, const int width, const int precision) + { + stringstream oss; + string sep = ""; + oss << title << "["; + for (const auto& item : data) { + oss << sep << fixed << setw(width) << setprecision(precision) << item.get(); + sep = ", "; + } + oss << "]"; + return oss.str(); + } + void ReportBase::show() + { + header(); + body(); + } +} \ No newline at end of file diff --git a/src/Platform/ReportBase.h b/src/Platform/ReportBase.h new file mode 100644 index 0000000..2acbbc7 --- /dev/null +++ b/src/Platform/ReportBase.h @@ -0,0 +1,23 @@ +#ifndef REPORTBASE_H +#define REPORTBASE_H +#include +#include +#include + +using json = nlohmann::json; +namespace platform { + using namespace std; + class ReportBase { + public: + explicit ReportBase(json data_) { data = data_; }; + virtual ~ReportBase() = default; + void show(); + protected: + json data; + string fromVector(const string& key); + string fVector(const string& title, const json& data, const int width, const int precision); + virtual void header() = 0; + virtual void body() = 0; + }; +}; +#endif \ No newline at end of file diff --git a/src/Platform/Report.cc b/src/Platform/ReportConsole.cc similarity index 81% rename from src/Platform/Report.cc rename to src/Platform/ReportConsole.cc index 5690668..2e3ed0c 100644 --- a/src/Platform/Report.cc +++ b/src/Platform/ReportConsole.cc @@ -1,52 +1,24 @@ #include #include -#include "Report.h" +#include "ReportConsole.h" #include "BestResult.h" namespace platform { - string headerLine(const string& text) - { - int n = MAXL - text.length() - 3; - n = n < 0 ? 0 : n; - return "* " + text + string(n, ' ') + "*\n"; - } - string Report::fromVector(const string& key) - { - stringstream oss; - string sep = ""; - oss << "["; - for (auto& item : data[key]) { - oss << sep << item.get(); - sep = ", "; - } - oss << "]"; - return oss.str(); - } - string fVector(const string& title, const json& data, const int width, const int precision) - { - stringstream oss; - string sep = ""; - oss << title << "["; - for (const auto& item : data) { - oss << sep << fixed << setw(width) << setprecision(precision) << item.get(); - sep = ", "; - } - oss << "]"; - return oss.str(); - } - void Report::show() - { - header(); - body(); - footer(); - } struct separated : numpunct { char do_decimal_point() const { return ','; } char do_thousands_sep() const { return '.'; } string do_grouping() const { return "\03"; } }; - void Report::header() + + string ReportConsole::headerLine(const string& text) + { + int n = MAXL - text.length() - 3; + n = n < 0 ? 0 : n; + return "* " + text + string(n, ' ') + "*\n"; + } + + void ReportConsole::header() { locale mylocale(cout.getloc(), new separated); locale::global(mylocale); @@ -62,12 +34,12 @@ namespace platform { cout << string(MAXL, '*') << endl; cout << endl; } - void Report::body() + void ReportConsole::body() { cout << Colors::GREEN() << "Dataset Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl; cout << "============================== ====== ===== === ========= ========= ========= =============== ================== ===============" << endl; json lastResult; - totalScore = 0; + double totalScore = 0.0; bool odd = true; for (const auto& r : data["results"]) { auto color = odd ? Colors::CYAN() : Colors::BLUE(); @@ -98,9 +70,11 @@ namespace platform { cout << headerLine(fVector("Train times: ", lastResult["times_train"], 10, 3)); cout << headerLine(fVector("Test times: ", lastResult["times_test"], 10, 3)); cout << string(MAXL, '*') << endl; + } else { + footer(totalScore); } } - void Report::footer() + void ReportConsole::footer(double totalScore) { cout << Colors::MAGENTA() << string(MAXL, '*') << endl; auto score = data["score_name"].get(); @@ -110,6 +84,5 @@ namespace platform { cout << headerLine(oss.str()); } cout << string(MAXL, '*') << endl << Colors::RESET(); - } } \ No newline at end of file diff --git a/src/Platform/ReportConsole.h b/src/Platform/ReportConsole.h new file mode 100644 index 0000000..5c795b7 --- /dev/null +++ b/src/Platform/ReportConsole.h @@ -0,0 +1,22 @@ +#ifndef REPORTCONSOLE_H +#define REPORTCONSOLE_H +#include +#include +#include "ReportBase.h" +#include "Colors.h" + +namespace platform { + using namespace std; + const int MAXL = 128; + class ReportConsole : public ReportBase{ + public: + explicit ReportConsole(json data_) : ReportBase(data_) {}; + virtual ~ReportConsole() = default; + private: + string headerLine(const string& text); + void header() override; + void body() override; + void footer(double totalScore); + }; +}; +#endif \ No newline at end of file diff --git a/src/Platform/ReportExcel.cc b/src/Platform/ReportExcel.cc new file mode 100644 index 0000000..bb51dfb --- /dev/null +++ b/src/Platform/ReportExcel.cc @@ -0,0 +1,109 @@ +#include +#include +#include "ReportExcel.h" +#include "BestResult.h" + + +namespace platform { + struct separated : numpunct { + char do_decimal_point() const { return ','; } + + char do_thousands_sep() const { return '.'; } + + string do_grouping() const { return "\03"; } + }; + + void ReportExcel::createFile() + { + doc.create(Paths::excel() + "some_results.xlsx"); + wks = doc.workbook().worksheet("Sheet1"); + wks.setName(data["model"].get()); + } + + void ReportExcel::closeFile() + { + doc.save(); + doc.close(); + } + + void ReportExcel::header() + { + locale mylocale(cout.getloc(), new separated); + locale::global(mylocale); + cout.imbue(mylocale); + stringstream oss; + wks.cell("A1").value().set( + "Report " + data["model"].get() + " ver. " + data["version"].get() + " with " + + to_string(data["folds"].get()) + " Folds cross validation and " + to_string(data["seeds"].size()) + + " random seeds. " + data["date"].get() + " " + data["time"].get()); + wks.cell("A2").value() = data["title"].get(); + wks.cell("A3").value() = "Random seeds: " + fromVector("seeds") + " Stratified: " + + (data["stratified"].get() ? "True" : "False"); + oss << "Execution took " << setprecision(2) << fixed << data["duration"].get() << " seconds, " + << data["duration"].get() / 3600 << " hours, on " << data["platform"].get(); + wks.cell("A4").value() = oss.str(); + wks.cell("A5").value() = "Score is " + data["score_name"].get(); + } + + void ReportExcel::body() + { + auto header = vector( + { "Dataset", "Samples", "Features", "Classes", "Nodes", "Edges", "States", "Score", "Score Std.", "Time", + "Time Std.", "Hyperparameters" }); + int col = 1; + for (const auto& item : header) { + wks.cell(8, col++).value() = item; + } + int row = 9; + col = 1; + json lastResult; + double totalScore = 0.0; + string hyperparameters; + for (const auto& r : data["results"]) { + wks.cell(row, col).value() = r["dataset"].get(); + wks.cell(row, col + 1).value() = r["samples"].get(); + wks.cell(row, col + 2).value() = r["features"].get(); + wks.cell(row, col + 3).value() = r["classes"].get(); + wks.cell(row, col + 4).value() = r["nodes"].get(); + wks.cell(row, col + 5).value() = r["leaves"].get(); + wks.cell(row, col + 6).value() = r["depth"].get(); + wks.cell(row, col + 7).value() = r["score"].get(); + wks.cell(row, col + 8).value() = r["score_std"].get(); + wks.cell(row, col + 9).value() = r["time"].get(); + wks.cell(row, col + 10).value() = r["time_std"].get(); + try { + hyperparameters = r["hyperparameters"].get(); + } + catch (const exception& err) { + stringstream oss; + oss << r["hyperparameters"]; + hyperparameters = oss.str(); + } + wks.cell(row, col + 11).value() = hyperparameters; + lastResult = r; + totalScore += r["score"].get(); + row++; + } + if (data["results"].size() == 1) { + for (const string& group : { "scores_train", "scores_test", "times_train", "times_test" }) { + row++; + col = 1; + wks.cell(row, col).value() = group; + for (double item : lastResult[group]) { + wks.cell(row, ++col).value() = item; + } + } + } else { + footer(totalScore, row); + } + } + + void ReportExcel::footer(double totalScore, int row) + { + auto score = data["score_name"].get(); + if (score == BestResult::scoreName()) { + wks.cell(row + 2, 1).value() = score + " compared to " + BestResult::title() + " .: "; + wks.cell(row + 2, 5).value() = totalScore / BestResult::score(); + } + } +} \ No newline at end of file diff --git a/src/Platform/ReportExcel.h b/src/Platform/ReportExcel.h new file mode 100644 index 0000000..3700681 --- /dev/null +++ b/src/Platform/ReportExcel.h @@ -0,0 +1,25 @@ +#ifndef REPORTEXCEL_H +#define REPORTEXCEL_H +#include +#include "ReportBase.h" +#include "Paths.h" +#include "Colors.h" +namespace platform { + using namespace std; + using namespace OpenXLSX; + const int MAXLL = 128; + class ReportExcel : public ReportBase{ + public: + explicit ReportExcel(json data_) : ReportBase(data_) {createFile();}; + virtual ~ReportExcel() {closeFile();}; + private: + void createFile(); + void closeFile(); + XLDocument doc; + XLWorksheet wks; + void header() override; + void body() override; + void footer(double totalScore, int row); + }; +}; +#endif // !REPORTEXCEL_H \ No newline at end of file diff --git a/src/Platform/Results.cc b/src/Platform/Results.cc index 818f51e..0440bc1 100644 --- a/src/Platform/Results.cc +++ b/src/Platform/Results.cc @@ -1,7 +1,8 @@ #include #include "platformUtils.h" #include "Results.h" -#include "Report.h" +#include "ReportConsole.h" +#include "ReportExcel.h" #include "BestResult.h" #include "Colors.h" namespace platform { @@ -94,21 +95,26 @@ namespace platform { cout << "Invalid index" << endl; return -1; } - void Results::report(const int index) const + void Results::report(const int index, const bool excelReport) const { cout << Colors::YELLOW() << "Reporting " << files.at(index).getFilename() << endl; auto data = files.at(index).load(); - Report report(data); - report.show(); + if (excelReport) { + ReportExcel report(data); + report.show(); + } else { + ReportConsole report(data); + report.show(); + } } void Results::menu() { char option; int index; bool finished = false; - string filename, line, options = "qldhsr"; + string filename, line, options = "qldhsre"; while (!finished) { - cout << Colors::RESET() << "Choose option (quit='q', list='l', delete='d', hide='h', sort='s', report='r'): "; + cout << Colors::RESET() << "Choose option (quit='q', list='l', delete='d', hide='h', sort='s', report='r', excel='e'): "; getline(cin, line); if (line.size() == 0) continue; @@ -119,12 +125,14 @@ namespace platform { } option = line[0]; } else { - index = stoi(line); - if (index >= 0 && index < files.size()) { - report(index); - } else { - cout << "Invalid option" << endl; + if (all_of(line.begin(), line.end(), ::isdigit)) { + index = stoi(line); + if (index >= 0 && index < files.size()) { + report(index, false); + continue; + } } + cout << "Invalid option" << endl; continue; } switch (option) { @@ -164,7 +172,13 @@ namespace platform { index = getIndex("report"); if (index == -1) break; - report(index); + report(index, false); + break; + case 'e': + index = getIndex("excel"); + if (index == -1) + break; + report(index, true); break; default: cout << "Invalid option" << endl; diff --git a/src/Platform/Results.h b/src/Platform/Results.h index e6b1552..39f3c55 100644 --- a/src/Platform/Results.h +++ b/src/Platform/Results.h @@ -42,7 +42,7 @@ namespace platform { vector files; void load(); // Loads the list of results void show() const; - void report(const int index) const; + void report(const int index, const bool excelReport) const; int getIndex(const string& intent) const; void menu(); void sortList();