Add classifier notes to console & excel report
This commit is contained in:
2
.gitmodules
vendored
2
.gitmodules
vendored
@@ -13,6 +13,8 @@
|
|||||||
[submodule "lib/mdlp"]
|
[submodule "lib/mdlp"]
|
||||||
path = lib/mdlp
|
path = lib/mdlp
|
||||||
url = https://github.com/rmontanana/mdlp
|
url = https://github.com/rmontanana/mdlp
|
||||||
|
update = merge
|
||||||
[submodule "lib/PyClassifiers"]
|
[submodule "lib/PyClassifiers"]
|
||||||
path = lib/PyClassifiers
|
path = lib/PyClassifiers
|
||||||
url = https://github.com/rmontanana/PyClassifiers
|
url = https://github.com/rmontanana/PyClassifiers
|
||||||
|
update = merge
|
||||||
|
Submodule lib/PyClassifiers updated: f46f6dcbb2...3226fb671b
Submodule lib/argparse updated: 69dabd88a8...1b3abd9b92
@@ -76,6 +76,7 @@ namespace platform {
|
|||||||
j["nodes"] = r.getNodes();
|
j["nodes"] = r.getNodes();
|
||||||
j["leaves"] = r.getLeaves();
|
j["leaves"] = r.getLeaves();
|
||||||
j["depth"] = r.getDepth();
|
j["depth"] = r.getDepth();
|
||||||
|
j["notes"] = r.getNotes();
|
||||||
result["results"].push_back(j);
|
result["results"].push_back(j);
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
@@ -162,7 +163,7 @@ namespace platform {
|
|||||||
if (!quiet) {
|
if (!quiet) {
|
||||||
std::cout << " " << setw(5) << samples << " " << setw(5) << features.size() << flush;
|
std::cout << " " << setw(5) << samples << " " << setw(5) << features.size() << flush;
|
||||||
}
|
}
|
||||||
// Prepare Result
|
// Prepare Resu lt
|
||||||
auto result = Result();
|
auto result = Result();
|
||||||
auto [values, counts] = at::_unique(y);
|
auto [values, counts] = at::_unique(y);
|
||||||
result.setSamples(X.size(1)).setFeatures(X.size(0)).setClasses(values.size(0));
|
result.setSamples(X.size(1)).setFeatures(X.size(0)).setClasses(values.size(0));
|
||||||
@@ -176,6 +177,7 @@ namespace platform {
|
|||||||
auto nodes = torch::zeros({ nResults }, torch::kFloat64);
|
auto nodes = torch::zeros({ nResults }, torch::kFloat64);
|
||||||
auto edges = torch::zeros({ nResults }, torch::kFloat64);
|
auto edges = torch::zeros({ nResults }, torch::kFloat64);
|
||||||
auto num_states = torch::zeros({ nResults }, torch::kFloat64);
|
auto num_states = torch::zeros({ nResults }, torch::kFloat64);
|
||||||
|
std::vector<std::string> notes;
|
||||||
Timer train_timer, test_timer;
|
Timer train_timer, test_timer;
|
||||||
int item = 0;
|
int item = 0;
|
||||||
bool first_seed = true;
|
bool first_seed = true;
|
||||||
@@ -214,6 +216,8 @@ namespace platform {
|
|||||||
clf->fit(X_train, y_train, features, className, states);
|
clf->fit(X_train, y_train, features, className, states);
|
||||||
if (!quiet)
|
if (!quiet)
|
||||||
showProgress(nfold + 1, getColor(clf->getStatus()), "b");
|
showProgress(nfold + 1, getColor(clf->getStatus()), "b");
|
||||||
|
auto clf_notes = clf->getNotes();
|
||||||
|
notes.insert(notes.end(), clf_notes.begin(), clf_notes.end());
|
||||||
nodes[item] = clf->getNumberOfNodes();
|
nodes[item] = clf->getNumberOfNodes();
|
||||||
edges[item] = clf->getNumberOfEdges();
|
edges[item] = clf->getNumberOfEdges();
|
||||||
num_states[item] = clf->getNumberOfStates();
|
num_states[item] = clf->getNumberOfStates();
|
||||||
@@ -248,7 +252,7 @@ namespace platform {
|
|||||||
result.setTrainTime(torch::mean(train_time).item<double>()).setTestTime(torch::mean(test_time).item<double>());
|
result.setTrainTime(torch::mean(train_time).item<double>()).setTestTime(torch::mean(test_time).item<double>());
|
||||||
result.setTestTimeStd(torch::std(test_time).item<double>()).setTrainTimeStd(torch::std(train_time).item<double>());
|
result.setTestTimeStd(torch::std(test_time).item<double>()).setTrainTimeStd(torch::std(train_time).item<double>());
|
||||||
result.setNodes(torch::mean(nodes).item<double>()).setLeaves(torch::mean(edges).item<double>()).setDepth(torch::mean(num_states).item<double>());
|
result.setNodes(torch::mean(nodes).item<double>()).setLeaves(torch::mean(edges).item<double>()).setDepth(torch::mean(num_states).item<double>());
|
||||||
result.setDataset(fileName);
|
result.setDataset(fileName).setNotes(notes);
|
||||||
addResult(result);
|
addResult(result);
|
||||||
}
|
}
|
||||||
}
|
}
|
@@ -21,9 +21,11 @@ namespace platform {
|
|||||||
double score_train{ 0 }, score_test{ 0 }, score_train_std{ 0 }, score_test_std{ 0 }, train_time{ 0 }, train_time_std{ 0 }, test_time{ 0 }, test_time_std{ 0 };
|
double score_train{ 0 }, score_test{ 0 }, score_train_std{ 0 }, score_test_std{ 0 }, train_time{ 0 }, train_time_std{ 0 }, test_time{ 0 }, test_time_std{ 0 };
|
||||||
float nodes{ 0 }, leaves{ 0 }, depth{ 0 };
|
float nodes{ 0 }, leaves{ 0 }, depth{ 0 };
|
||||||
std::vector<double> scores_train, scores_test, times_train, times_test;
|
std::vector<double> scores_train, scores_test, times_train, times_test;
|
||||||
|
std::vector<std::string> notes;
|
||||||
public:
|
public:
|
||||||
Result() = default;
|
Result() = default;
|
||||||
Result& setDataset(const std::string& dataset) { this->dataset = dataset; return *this; }
|
Result& setDataset(const std::string& dataset) { this->dataset = dataset; return *this; }
|
||||||
|
Result& setNotes(const std::vector<std::string>& notes) { this->notes.insert(this->notes.end(), notes.begin(), notes.end()); return *this; }
|
||||||
Result& setHyperparameters(const json& hyperparameters) { this->hyperparameters = hyperparameters; return *this; }
|
Result& setHyperparameters(const json& hyperparameters) { this->hyperparameters = hyperparameters; return *this; }
|
||||||
Result& setSamples(int samples) { this->samples = samples; return *this; }
|
Result& setSamples(int samples) { this->samples = samples; return *this; }
|
||||||
Result& setFeatures(int features) { this->features = features; return *this; }
|
Result& setFeatures(int features) { this->features = features; return *this; }
|
||||||
@@ -61,6 +63,7 @@ namespace platform {
|
|||||||
const float getNodes() const { return nodes; }
|
const float getNodes() const { return nodes; }
|
||||||
const float getLeaves() const { return leaves; }
|
const float getLeaves() const { return leaves; }
|
||||||
const float getDepth() const { return depth; }
|
const float getDepth() const { return depth; }
|
||||||
|
const std::vector<std::string>& getNotes() const { return notes; }
|
||||||
const std::vector<double>& getScoresTrain() const { return scores_train; }
|
const std::vector<double>& getScoresTrain() const { return scores_train; }
|
||||||
const std::vector<double>& getScoresTest() const { return scores_test; }
|
const std::vector<double>& getScoresTest() const { return scores_test; }
|
||||||
const std::vector<double>& getTimesTrain() const { return times_train; }
|
const std::vector<double>& getTimesTrain() const { return times_train; }
|
||||||
|
@@ -57,7 +57,11 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
auto color = odd ? Colors::CYAN() : Colors::BLUE();
|
auto color = odd ? Colors::CYAN() : Colors::BLUE();
|
||||||
std::cout << color;
|
std::cout << color;
|
||||||
std::cout << std::setw(3) << std::right << index++ << " ";
|
std::string separator{ " " };
|
||||||
|
if (r.find("notes") != r.end()) {
|
||||||
|
separator = r["notes"].size() > 0 ? Colors::YELLOW() + Symbols::notebook + color : " ";
|
||||||
|
}
|
||||||
|
std::cout << std::setw(3) << std::right << index++ << separator;
|
||||||
std::cout << std::setw(maxDataset) << std::left << r["dataset"].get<std::string>() << " ";
|
std::cout << std::setw(maxDataset) << std::left << r["dataset"].get<std::string>() << " ";
|
||||||
std::cout << std::setw(6) << std::right << r["samples"].get<int>() << " ";
|
std::cout << std::setw(6) << std::right << r["samples"].get<int>() << " ";
|
||||||
std::cout << std::setw(5) << std::right << r["features"].get<int>() << " ";
|
std::cout << std::setw(5) << std::right << r["features"].get<int>() << " ";
|
||||||
@@ -78,6 +82,14 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
if (data["results"].size() == 1 || selectedIndex != -1) {
|
if (data["results"].size() == 1 || selectedIndex != -1) {
|
||||||
std::cout << std::string(MAXL, '*') << std::endl;
|
std::cout << std::string(MAXL, '*') << std::endl;
|
||||||
|
if (lastResult.find("notes") != lastResult.end()) {
|
||||||
|
if (lastResult["notes"].size() > 0) {
|
||||||
|
std::cout << headerLine("Notes: ");
|
||||||
|
for (const auto& note : lastResult["notes"]) {
|
||||||
|
std::cout << headerLine(note.get<std::string>());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
std::cout << headerLine(fVector("Train scores: ", lastResult["scores_train"], 14, 12));
|
std::cout << headerLine(fVector("Train scores: ", lastResult["scores_train"], 14, 12));
|
||||||
std::cout << headerLine(fVector("Test scores: ", lastResult["scores_test"], 14, 12));
|
std::cout << headerLine(fVector("Test scores: ", lastResult["scores_test"], 14, 12));
|
||||||
std::cout << headerLine(fVector("Train times: ", lastResult["times_train"], 10, 3));
|
std::cout << headerLine(fVector("Train times: ", lastResult["times_train"], 10, 3));
|
||||||
|
@@ -133,6 +133,16 @@ namespace platform {
|
|||||||
worksheet_set_column(worksheet, 12, 12, hypSize + 5, NULL);
|
worksheet_set_column(worksheet, 12, 12, hypSize + 5, NULL);
|
||||||
// Show totals if only one dataset is present in the result
|
// Show totals if only one dataset is present in the result
|
||||||
if (data["results"].size() == 1) {
|
if (data["results"].size() == 1) {
|
||||||
|
row++;
|
||||||
|
if (lastResult.find("notes") != lastResult.end()) {
|
||||||
|
if (lastResult["notes"].size() > 0) {
|
||||||
|
writeString(row++, 1, "Notes: ", "bodyHeader");
|
||||||
|
for (const auto& note : lastResult["notes"]) {
|
||||||
|
worksheet_merge_range(worksheet, row, 2, row, 5, note.get<std::string>().c_str(), efectiveStyle("text"));
|
||||||
|
row++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
for (const std::string& group : { "scores_train", "scores_test", "times_train", "times_test" }) {
|
for (const std::string& group : { "scores_train", "scores_test", "times_train", "times_test" }) {
|
||||||
row++;
|
row++;
|
||||||
col = 1;
|
col = 1;
|
||||||
|
@@ -12,6 +12,7 @@ namespace platform {
|
|||||||
inline static const std::string down_arrow{ "\u27B4" };
|
inline static const std::string down_arrow{ "\u27B4" };
|
||||||
inline static const std::string equal_best{ check_mark };
|
inline static const std::string equal_best{ check_mark };
|
||||||
inline static const std::string better_best{ black_star };
|
inline static const std::string better_best{ black_star };
|
||||||
|
inline static const std::string notebook{ "\U0001F5C8" };
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif // !SYMBOLS_H
|
#endif // !SYMBOLS_H
|
Reference in New Issue
Block a user