Add classifier notes to console & Excel report

This commit is contained in:
2024-02-09 18:08:08 +01:00
parent 0907906ef6
commit 3f3c14e8fc
8 changed files with 37 additions and 5 deletions

2
.gitmodules vendored
View File

@@ -13,6 +13,8 @@
[submodule "lib/mdlp"]
path = lib/mdlp
url = https://github.com/rmontanana/mdlp
update = merge
[submodule "lib/PyClassifiers"]
path = lib/PyClassifiers
url = https://github.com/rmontanana/PyClassifiers
update = merge

View File

@@ -76,6 +76,7 @@ namespace platform {
j["nodes"] = r.getNodes();
j["leaves"] = r.getLeaves();
j["depth"] = r.getDepth();
j["notes"] = r.getNotes();
result["results"].push_back(j);
}
return result;
@@ -176,6 +177,7 @@ namespace platform {
auto nodes = torch::zeros({ nResults }, torch::kFloat64);
auto edges = torch::zeros({ nResults }, torch::kFloat64);
auto num_states = torch::zeros({ nResults }, torch::kFloat64);
std::vector<std::string> notes;
Timer train_timer, test_timer;
int item = 0;
bool first_seed = true;
@@ -214,6 +216,8 @@ namespace platform {
clf->fit(X_train, y_train, features, className, states);
if (!quiet)
showProgress(nfold + 1, getColor(clf->getStatus()), "b");
auto clf_notes = clf->getNotes();
notes.insert(notes.end(), clf_notes.begin(), clf_notes.end());
nodes[item] = clf->getNumberOfNodes();
edges[item] = clf->getNumberOfEdges();
num_states[item] = clf->getNumberOfStates();
@@ -248,7 +252,7 @@ namespace platform {
result.setTrainTime(torch::mean(train_time).item<double>()).setTestTime(torch::mean(test_time).item<double>());
result.setTestTimeStd(torch::std(test_time).item<double>()).setTrainTimeStd(torch::std(train_time).item<double>());
result.setNodes(torch::mean(nodes).item<double>()).setLeaves(torch::mean(edges).item<double>()).setDepth(torch::mean(num_states).item<double>());
result.setDataset(fileName);
result.setDataset(fileName).setNotes(notes);
addResult(result);
}
}

View File

@@ -21,9 +21,11 @@ namespace platform {
double score_train{ 0 }, score_test{ 0 }, score_train_std{ 0 }, score_test_std{ 0 }, train_time{ 0 }, train_time_std{ 0 }, test_time{ 0 }, test_time_std{ 0 };
float nodes{ 0 }, leaves{ 0 }, depth{ 0 };
std::vector<double> scores_train, scores_test, times_train, times_test;
std::vector<std::string> notes;
public:
Result() = default;
/// Stores the dataset name and returns this object so setter calls can be chained.
Result& setDataset(const std::string& dataset)
{
    this->dataset = dataset;
    return *this;
}
/// Appends the given notes after the ones already stored (existing notes are kept,
/// not replaced) and returns this object so setter calls can be chained.
Result& setNotes(const std::vector<std::string>& notes)
{
    auto& stored = this->notes;
    stored.insert(stored.end(), notes.begin(), notes.end());
    return *this;
}
/// Stores a copy of the hyperparameters JSON and returns this object for chaining.
Result& setHyperparameters(const json& hyperparameters)
{
    this->hyperparameters = hyperparameters;
    return *this;
}
/// Stores the samples value and returns this object for chaining.
Result& setSamples(int samples)
{
    this->samples = samples;
    return *this;
}
/// Stores the features value and returns this object for chaining.
Result& setFeatures(int features)
{
    this->features = features;
    return *this;
}
@@ -61,6 +63,7 @@ namespace platform {
const float getNodes() const { return nodes; }
const float getLeaves() const { return leaves; }
const float getDepth() const { return depth; }
/// Read-only reference to the stored notes; no copy is made.
const std::vector<std::string>& getNotes() const
{
    return notes;
}
/// Read-only reference to the stored train scores; no copy is made.
const std::vector<double>& getScoresTrain() const
{
    return scores_train;
}
/// Read-only reference to the stored test scores; no copy is made.
const std::vector<double>& getScoresTest() const
{
    return scores_test;
}
/// Read-only reference to the stored train times; no copy is made.
const std::vector<double>& getTimesTrain() const
{
    return times_train;
}

View File

@@ -57,7 +57,11 @@ namespace platform {
}
auto color = odd ? Colors::CYAN() : Colors::BLUE();
std::cout << color;
std::cout << std::setw(3) << std::right << index++ << " ";
std::string separator{ " " };
if (r.find("notes") != r.end()) {
separator = r["notes"].size() > 0 ? Colors::YELLOW() + Symbols::notebook + color : " ";
}
std::cout << std::setw(3) << std::right << index++ << separator;
std::cout << std::setw(maxDataset) << std::left << r["dataset"].get<std::string>() << " ";
std::cout << std::setw(6) << std::right << r["samples"].get<int>() << " ";
std::cout << std::setw(5) << std::right << r["features"].get<int>() << " ";
@@ -78,6 +82,14 @@ namespace platform {
}
if (data["results"].size() == 1 || selectedIndex != -1) {
std::cout << std::string(MAXL, '*') << std::endl;
if (lastResult.find("notes") != lastResult.end()) {
if (lastResult["notes"].size() > 0) {
std::cout << headerLine("Notes: ");
for (const auto& note : lastResult["notes"]) {
std::cout << headerLine(note.get<std::string>());
}
}
}
std::cout << headerLine(fVector("Train scores: ", lastResult["scores_train"], 14, 12));
std::cout << headerLine(fVector("Test scores: ", lastResult["scores_test"], 14, 12));
std::cout << headerLine(fVector("Train times: ", lastResult["times_train"], 10, 3));

View File

@@ -133,6 +133,16 @@ namespace platform {
worksheet_set_column(worksheet, 12, 12, hypSize + 5, NULL);
// Show totals if only one dataset is present in the result
if (data["results"].size() == 1) {
row++;
if (lastResult.find("notes") != lastResult.end()) {
if (lastResult["notes"].size() > 0) {
writeString(row++, 1, "Notes: ", "bodyHeader");
for (const auto& note : lastResult["notes"]) {
worksheet_merge_range(worksheet, row, 2, row, 5, note.get<std::string>().c_str(), efectiveStyle("text"));
row++;
}
}
}
for (const std::string& group : { "scores_train", "scores_test", "times_train", "times_test" }) {
row++;
col = 1;

View File

@@ -12,6 +12,7 @@ namespace platform {
// Glyph at Unicode code point U+27B4, exposed as the "down arrow" symbol.
inline static const std::string down_arrow{ "\u27B4" };
// Copy of check_mark (defined outside this excerpt).
inline static const std::string equal_best{ check_mark };
// Copy of black_star (defined outside this excerpt).
inline static const std::string better_best{ black_star };
// Glyph at Unicode code point U+1F5C8; presumably the marker shown next to
// results that carry classifier notes — confirm against the report code.
inline static const std::string notebook{ "\U0001F5C8" };
};
}
#endif // !SYMBOLS_H