diff --git a/.gitmodules b/.gitmodules index da340e7..7ba28d1 100644 --- a/.gitmodules +++ b/.gitmodules @@ -13,6 +13,8 @@ [submodule "lib/mdlp"] path = lib/mdlp url = https://github.com/rmontanana/mdlp + update = merge [submodule "lib/PyClassifiers"] path = lib/PyClassifiers url = https://github.com/rmontanana/PyClassifiers + update = merge diff --git a/lib/PyClassifiers b/lib/PyClassifiers index f46f6dc..3226fb6 160000 --- a/lib/PyClassifiers +++ b/lib/PyClassifiers @@ -1 +1 @@ -Subproject commit f46f6dcbb270413ec9760b5de7abb81f6a932df6 +Subproject commit 3226fb671b056c7cec695827515bed57935c08d1 diff --git a/lib/argparse b/lib/argparse index 69dabd8..1b3abd9 160000 --- a/lib/argparse +++ b/lib/argparse @@ -1 +1 @@ -Subproject commit 69dabd88a8e6680b1a1a18397eb3e165e4019ce6 +Subproject commit 1b3abd9b929c8d4aed08632824374cf8a55f5a74 diff --git a/src/Platform/modules/Experiment.cc b/src/Platform/modules/Experiment.cc index 1b7d891..891cbb7 100644 --- a/src/Platform/modules/Experiment.cc +++ b/src/Platform/modules/Experiment.cc @@ -76,6 +76,7 @@ namespace platform { j["nodes"] = r.getNodes(); j["leaves"] = r.getLeaves(); j["depth"] = r.getDepth(); + j["notes"] = r.getNotes(); result["results"].push_back(j); } return result; @@ -162,7 +163,7 @@ namespace platform { if (!quiet) { std::cout << " " << setw(5) << samples << " " << setw(5) << features.size() << flush; } - // Prepare Result + // Prepare Resu lt auto result = Result(); auto [values, counts] = at::_unique(y); result.setSamples(X.size(1)).setFeatures(X.size(0)).setClasses(values.size(0)); @@ -176,6 +177,7 @@ namespace platform { auto nodes = torch::zeros({ nResults }, torch::kFloat64); auto edges = torch::zeros({ nResults }, torch::kFloat64); auto num_states = torch::zeros({ nResults }, torch::kFloat64); + std::vector notes; Timer train_timer, test_timer; int item = 0; bool first_seed = true; @@ -214,6 +216,8 @@ namespace platform { clf->fit(X_train, y_train, features, className, states); if (!quiet) showProgress(nfold + 1, getColor(clf->getStatus()), "b"); + auto clf_notes = clf->getNotes(); + notes.insert(notes.end(), clf_notes.begin(), clf_notes.end()); nodes[item] = clf->getNumberOfNodes(); edges[item] = clf->getNumberOfEdges(); num_states[item] = clf->getNumberOfStates(); @@ -248,7 +252,7 @@ namespace platform { result.setTrainTime(torch::mean(train_time).item()).setTestTime(torch::mean(test_time).item()); result.setTestTimeStd(torch::std(test_time).item()).setTrainTimeStd(torch::std(train_time).item()); result.setNodes(torch::mean(nodes).item()).setLeaves(torch::mean(edges).item()).setDepth(torch::mean(num_states).item()); - result.setDataset(fileName); + result.setDataset(fileName).setNotes(notes); addResult(result); } } \ No newline at end of file diff --git a/src/Platform/modules/Experiment.h b/src/Platform/modules/Experiment.h index 252c7c9..53f7e85 100644 --- a/src/Platform/modules/Experiment.h +++ b/src/Platform/modules/Experiment.h @@ -21,9 +21,11 @@ namespace platform { double score_train{ 0 }, score_test{ 0 }, score_train_std{ 0 }, score_test_std{ 0 }, train_time{ 0 }, train_time_std{ 0 }, test_time{ 0 }, test_time_std{ 0 }; float nodes{ 0 }, leaves{ 0 }, depth{ 0 }; std::vector scores_train, scores_test, times_train, times_test; + std::vector notes; public: Result() = default; Result& setDataset(const std::string& dataset) { this->dataset = dataset; return *this; } + Result& setNotes(const std::vector& notes) { this->notes.insert(this->notes.end(), notes.begin(), notes.end()); return *this; } Result& setHyperparameters(const json& hyperparameters) { this->hyperparameters = hyperparameters; return *this; } Result& setSamples(int samples) { this->samples = samples; return *this; } Result& setFeatures(int features) { this->features = features; return *this; } @@ -61,6 +63,7 @@ namespace platform { const float getNodes() const { return nodes; } const float getLeaves() const { return leaves; } const float getDepth() const { return depth; } + const std::vector& getNotes() const { return notes; } const std::vector& getScoresTrain() const { return scores_train; } const std::vector& getScoresTest() const { return scores_test; } const std::vector& getTimesTrain() const { return times_train; } diff --git a/src/Platform/modules/ReportConsole.cc b/src/Platform/modules/ReportConsole.cc index 723c4df..7f6ef0d 100644 --- a/src/Platform/modules/ReportConsole.cc +++ b/src/Platform/modules/ReportConsole.cc @@ -57,7 +57,11 @@ namespace platform { } auto color = odd ? Colors::CYAN() : Colors::BLUE(); std::cout << color; - std::cout << std::setw(3) << std::right << index++ << " "; + std::string separator{ " " }; + if (r.find("notes") != r.end()) { + separator = r["notes"].size() > 0 ? Colors::YELLOW() + Symbols::notebook + color : " "; + } + std::cout << std::setw(3) << std::right << index++ << separator; std::cout << std::setw(maxDataset) << std::left << r["dataset"].get() << " "; std::cout << std::setw(6) << std::right << r["samples"].get() << " "; std::cout << std::setw(5) << std::right << r["features"].get() << " "; @@ -78,6 +82,14 @@ namespace platform { } if (data["results"].size() == 1 || selectedIndex != -1) { std::cout << std::string(MAXL, '*') << std::endl; + if (lastResult.find("notes") != lastResult.end()) { + if (lastResult["notes"].size() > 0) { + std::cout << headerLine("Notes: "); + for (const auto& note : lastResult["notes"]) { + std::cout << headerLine(note.get()); + } + } + } std::cout << headerLine(fVector("Train scores: ", lastResult["scores_train"], 14, 12)); std::cout << headerLine(fVector("Test scores: ", lastResult["scores_test"], 14, 12)); std::cout << headerLine(fVector("Train times: ", lastResult["times_train"], 10, 3)); diff --git a/src/Platform/modules/ReportExcel.cc b/src/Platform/modules/ReportExcel.cc index addbf4c..c3b6f77 100644 --- a/src/Platform/modules/ReportExcel.cc +++ b/src/Platform/modules/ReportExcel.cc @@ -133,6 +133,16 @@ namespace platform { worksheet_set_column(worksheet, 12, 12, hypSize + 5, NULL); // Show totals if only one dataset is present in the result if (data["results"].size() == 1) { + row++; + if (lastResult.find("notes") != lastResult.end()) { + if (lastResult["notes"].size() > 0) { + writeString(row++, 1, "Notes: ", "bodyHeader"); + for (const auto& note : lastResult["notes"]) { + worksheet_merge_range(worksheet, row, 2, row, 5, note.get().c_str(), efectiveStyle("text")); + row++; + } + } + } for (const std::string& group : { "scores_train", "scores_test", "times_train", "times_test" }) { row++; col = 1; diff --git a/src/Platform/modules/Symbols.h b/src/Platform/modules/Symbols.h index 5a8c9be..3aa837e 100644 --- a/src/Platform/modules/Symbols.h +++ b/src/Platform/modules/Symbols.h @@ -12,6 +12,7 @@ namespace platform { inline static const std::string down_arrow{ "\u27B4" }; inline static const std::string equal_best{ check_mark }; inline static const std::string better_best{ black_star }; + inline static const std::string notebook{ "\U0001F5C8" }; }; } #endif // !SYMBOLS_H \ No newline at end of file