Add classifier notes to console & excel report
This commit is contained in:
2
.gitmodules
vendored
2
.gitmodules
vendored
@@ -13,6 +13,8 @@
|
||||
[submodule "lib/mdlp"]
|
||||
path = lib/mdlp
|
||||
url = https://github.com/rmontanana/mdlp
|
||||
update = merge
|
||||
[submodule "lib/PyClassifiers"]
|
||||
path = lib/PyClassifiers
|
||||
url = https://github.com/rmontanana/PyClassifiers
|
||||
update = merge
|
||||
|
Submodule lib/PyClassifiers updated: f46f6dcbb2...3226fb671b
Submodule lib/argparse updated: 69dabd88a8...1b3abd9b92
@@ -76,6 +76,7 @@ namespace platform {
|
||||
j["nodes"] = r.getNodes();
|
||||
j["leaves"] = r.getLeaves();
|
||||
j["depth"] = r.getDepth();
|
||||
j["notes"] = r.getNotes();
|
||||
result["results"].push_back(j);
|
||||
}
|
||||
return result;
|
||||
@@ -162,7 +163,7 @@ namespace platform {
|
||||
if (!quiet) {
|
||||
std::cout << " " << setw(5) << samples << " " << setw(5) << features.size() << flush;
|
||||
}
|
||||
// Prepare Result
|
||||
// Prepare Resu lt
|
||||
auto result = Result();
|
||||
auto [values, counts] = at::_unique(y);
|
||||
result.setSamples(X.size(1)).setFeatures(X.size(0)).setClasses(values.size(0));
|
||||
@@ -176,6 +177,7 @@ namespace platform {
|
||||
auto nodes = torch::zeros({ nResults }, torch::kFloat64);
|
||||
auto edges = torch::zeros({ nResults }, torch::kFloat64);
|
||||
auto num_states = torch::zeros({ nResults }, torch::kFloat64);
|
||||
std::vector<std::string> notes;
|
||||
Timer train_timer, test_timer;
|
||||
int item = 0;
|
||||
bool first_seed = true;
|
||||
@@ -214,6 +216,8 @@ namespace platform {
|
||||
clf->fit(X_train, y_train, features, className, states);
|
||||
if (!quiet)
|
||||
showProgress(nfold + 1, getColor(clf->getStatus()), "b");
|
||||
auto clf_notes = clf->getNotes();
|
||||
notes.insert(notes.end(), clf_notes.begin(), clf_notes.end());
|
||||
nodes[item] = clf->getNumberOfNodes();
|
||||
edges[item] = clf->getNumberOfEdges();
|
||||
num_states[item] = clf->getNumberOfStates();
|
||||
@@ -248,7 +252,7 @@ namespace platform {
|
||||
result.setTrainTime(torch::mean(train_time).item<double>()).setTestTime(torch::mean(test_time).item<double>());
|
||||
result.setTestTimeStd(torch::std(test_time).item<double>()).setTrainTimeStd(torch::std(train_time).item<double>());
|
||||
result.setNodes(torch::mean(nodes).item<double>()).setLeaves(torch::mean(edges).item<double>()).setDepth(torch::mean(num_states).item<double>());
|
||||
result.setDataset(fileName);
|
||||
result.setDataset(fileName).setNotes(notes);
|
||||
addResult(result);
|
||||
}
|
||||
}
|
@@ -21,9 +21,11 @@ namespace platform {
|
||||
double score_train{ 0 }, score_test{ 0 }, score_train_std{ 0 }, score_test_std{ 0 }, train_time{ 0 }, train_time_std{ 0 }, test_time{ 0 }, test_time_std{ 0 };
|
||||
float nodes{ 0 }, leaves{ 0 }, depth{ 0 };
|
||||
std::vector<double> scores_train, scores_test, times_train, times_test;
|
||||
std::vector<std::string> notes;
|
||||
public:
|
||||
Result() = default;
|
||||
Result& setDataset(const std::string& dataset) { this->dataset = dataset; return *this; }
|
||||
Result& setNotes(const std::vector<std::string>& notes) { this->notes.insert(this->notes.end(), notes.begin(), notes.end()); return *this; }
|
||||
Result& setHyperparameters(const json& hyperparameters) { this->hyperparameters = hyperparameters; return *this; }
|
||||
Result& setSamples(int samples) { this->samples = samples; return *this; }
|
||||
Result& setFeatures(int features) { this->features = features; return *this; }
|
||||
@@ -61,6 +63,7 @@ namespace platform {
|
||||
const float getNodes() const { return nodes; }
|
||||
const float getLeaves() const { return leaves; }
|
||||
const float getDepth() const { return depth; }
|
||||
const std::vector<std::string>& getNotes() const { return notes; }
|
||||
const std::vector<double>& getScoresTrain() const { return scores_train; }
|
||||
const std::vector<double>& getScoresTest() const { return scores_test; }
|
||||
const std::vector<double>& getTimesTrain() const { return times_train; }
|
||||
|
@@ -57,7 +57,11 @@ namespace platform {
|
||||
}
|
||||
auto color = odd ? Colors::CYAN() : Colors::BLUE();
|
||||
std::cout << color;
|
||||
std::cout << std::setw(3) << std::right << index++ << " ";
|
||||
std::string separator{ " " };
|
||||
if (r.find("notes") != r.end()) {
|
||||
separator = r["notes"].size() > 0 ? Colors::YELLOW() + Symbols::notebook + color : " ";
|
||||
}
|
||||
std::cout << std::setw(3) << std::right << index++ << separator;
|
||||
std::cout << std::setw(maxDataset) << std::left << r["dataset"].get<std::string>() << " ";
|
||||
std::cout << std::setw(6) << std::right << r["samples"].get<int>() << " ";
|
||||
std::cout << std::setw(5) << std::right << r["features"].get<int>() << " ";
|
||||
@@ -78,6 +82,14 @@ namespace platform {
|
||||
}
|
||||
if (data["results"].size() == 1 || selectedIndex != -1) {
|
||||
std::cout << std::string(MAXL, '*') << std::endl;
|
||||
if (lastResult.find("notes") != lastResult.end()) {
|
||||
if (lastResult["notes"].size() > 0) {
|
||||
std::cout << headerLine("Notes: ");
|
||||
for (const auto& note : lastResult["notes"]) {
|
||||
std::cout << headerLine(note.get<std::string>());
|
||||
}
|
||||
}
|
||||
}
|
||||
std::cout << headerLine(fVector("Train scores: ", lastResult["scores_train"], 14, 12));
|
||||
std::cout << headerLine(fVector("Test scores: ", lastResult["scores_test"], 14, 12));
|
||||
std::cout << headerLine(fVector("Train times: ", lastResult["times_train"], 10, 3));
|
||||
|
@@ -133,6 +133,16 @@ namespace platform {
|
||||
worksheet_set_column(worksheet, 12, 12, hypSize + 5, NULL);
|
||||
// Show totals if only one dataset is present in the result
|
||||
if (data["results"].size() == 1) {
|
||||
row++;
|
||||
if (lastResult.find("notes") != lastResult.end()) {
|
||||
if (lastResult["notes"].size() > 0) {
|
||||
writeString(row++, 1, "Notes: ", "bodyHeader");
|
||||
for (const auto& note : lastResult["notes"]) {
|
||||
worksheet_merge_range(worksheet, row, 2, row, 5, note.get<std::string>().c_str(), efectiveStyle("text"));
|
||||
row++;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (const std::string& group : { "scores_train", "scores_test", "times_train", "times_test" }) {
|
||||
row++;
|
||||
col = 1;
|
||||
|
@@ -12,6 +12,7 @@ namespace platform {
|
||||
inline static const std::string down_arrow{ "\u27B4" };
|
||||
inline static const std::string equal_best{ check_mark };
|
||||
inline static const std::string better_best{ black_star };
|
||||
inline static const std::string notebook{ "\U0001F5C8" };
|
||||
};
|
||||
}
|
||||
#endif // !SYMBOLS_H
|
Reference in New Issue
Block a user