Add train classification report
This commit is contained in:
@@ -102,6 +102,7 @@ namespace platform {
|
|||||||
auto edges = torch::zeros({ nResults }, torch::kFloat64);
|
auto edges = torch::zeros({ nResults }, torch::kFloat64);
|
||||||
auto num_states = torch::zeros({ nResults }, torch::kFloat64);
|
auto num_states = torch::zeros({ nResults }, torch::kFloat64);
|
||||||
json confusion_matrices = json::array();
|
json confusion_matrices = json::array();
|
||||||
|
json confusion_matrices_train = json::array();
|
||||||
std::vector<std::string> notes;
|
std::vector<std::string> notes;
|
||||||
Timer train_timer, test_timer;
|
Timer train_timer, test_timer;
|
||||||
int item = 0;
|
int item = 0;
|
||||||
@@ -150,8 +151,12 @@ namespace platform {
|
|||||||
train_time[item] = train_timer.getDuration();
|
train_time[item] = train_timer.getDuration();
|
||||||
double accuracy_train_value = 0.0;
|
double accuracy_train_value = 0.0;
|
||||||
// Score train
|
// Score train
|
||||||
if (!no_train_score)
|
if (!no_train_score) {
|
||||||
accuracy_train_value = clf->score(X_train, y_train);
|
auto y_predict = clf->predict(X_train);
|
||||||
|
Scores scores(y_train, y_predict, states[className].size(), labels);
|
||||||
|
accuracy_train_value = scores.accuracy();
|
||||||
|
confusion_matrices_train.push_back(scores.get_confusion_matrix_json(true));
|
||||||
|
}
|
||||||
// Test model
|
// Test model
|
||||||
if (!quiet)
|
if (!quiet)
|
||||||
showProgress(nfold + 1, getColor(clf->getStatus()), "c");
|
showProgress(nfold + 1, getColor(clf->getStatus()), "c");
|
||||||
@@ -183,6 +188,8 @@ namespace platform {
|
|||||||
partial_result.setNodes(torch::mean(nodes).item<double>()).setLeaves(torch::mean(edges).item<double>()).setDepth(torch::mean(num_states).item<double>());
|
partial_result.setNodes(torch::mean(nodes).item<double>()).setLeaves(torch::mean(edges).item<double>()).setDepth(torch::mean(num_states).item<double>());
|
||||||
partial_result.setDataset(fileName).setNotes(notes);
|
partial_result.setDataset(fileName).setNotes(notes);
|
||||||
partial_result.setConfusionMatrices(confusion_matrices);
|
partial_result.setConfusionMatrices(confusion_matrices);
|
||||||
|
if (!no_train_score)
|
||||||
|
partial_result.setConfusionMatricesTrain(confusion_matrices_train);
|
||||||
addResult(partial_result);
|
addResult(partial_result);
|
||||||
}
|
}
|
||||||
}
|
}
|
@@ -28,6 +28,7 @@ namespace platform {
|
|||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
PartialResult& setConfusionMatrices(const json& confusion_matrices) { data["confusion_matrices"] = confusion_matrices; return *this; }
|
PartialResult& setConfusionMatrices(const json& confusion_matrices) { data["confusion_matrices"] = confusion_matrices; return *this; }
|
||||||
|
PartialResult& setConfusionMatricesTrain(const json& confusion_matrices) { data["confusion_matrices_train"] = confusion_matrices; return *this; }
|
||||||
PartialResult& setHyperparameters(const json& hyperparameters) { data["hyperparameters"] = hyperparameters; return *this; }
|
PartialResult& setHyperparameters(const json& hyperparameters) { data["hyperparameters"] = hyperparameters; return *this; }
|
||||||
PartialResult& setSamples(int samples) { data["samples"] = samples; return *this; }
|
PartialResult& setSamples(int samples) { data["samples"] = samples; return *this; }
|
||||||
PartialResult& setFeatures(int features) { data["features"] = features; return *this; }
|
PartialResult& setFeatures(int features) { data["features"] = features; return *this; }
|
||||||
|
@@ -126,24 +126,28 @@ namespace platform {
|
|||||||
oss << std::setw(dlen) << std::setprecision(ndec) << std::fixed << recall << " ";
|
oss << std::setw(dlen) << std::setprecision(ndec) << std::fixed << recall << " ";
|
||||||
}
|
}
|
||||||
oss << std::setw(dlen) << std::setprecision(ndec) << std::fixed << f1_score << " "
|
oss << std::setw(dlen) << std::setprecision(ndec) << std::fixed << f1_score << " "
|
||||||
<< std::setw(dlen) << std::right << support << std::endl;
|
<< std::setw(dlen) << std::right << support;
|
||||||
return oss.str();
|
return oss.str();
|
||||||
}
|
}
|
||||||
std::string Scores::classification_report(std::string color)
|
std::vector<std::string> Scores::classification_report(std::string color, std::string title)
|
||||||
{
|
{
|
||||||
std::stringstream oss;
|
std::stringstream oss;
|
||||||
|
std::vector<std::string> report;
|
||||||
for (int i = 0; i < num_classes; i++) {
|
for (int i = 0; i < num_classes; i++) {
|
||||||
label_len = std::max(label_len, (int)labels[i].size());
|
label_len = std::max(label_len, (int)labels[i].size());
|
||||||
}
|
}
|
||||||
oss << Colors::GREEN() << "Classification Report" << std::endl;
|
report.push_back("Classification Report using " + title + " dataset");
|
||||||
oss << "=====================" << std::endl << color;
|
report.push_back("=========================================");
|
||||||
oss << std::string(label_len, ' ') << " precision recall f1-score support" << std::endl;
|
oss << std::string(label_len, ' ') << " precision recall f1-score support";
|
||||||
oss << std::string(label_len, ' ') << " ========= ========= ========= =========" << std::endl;
|
report.push_back(oss.str()); oss.str("");
|
||||||
|
oss << std::string(label_len, ' ') << " ========= ========= ========= =========";
|
||||||
|
report.push_back(oss.str()); oss.str("");
|
||||||
for (int i = 0; i < num_classes; i++) {
|
for (int i = 0; i < num_classes; i++) {
|
||||||
oss << classification_report_line(labels[i], precision(i), recall(i), f1_score(i), confusion_matrix[i].sum().item<int>());
|
report.push_back(classification_report_line(labels[i], precision(i), recall(i), f1_score(i), confusion_matrix[i].sum().item<int>()));
|
||||||
}
|
}
|
||||||
oss << std::endl;
|
report.push_back(" ");
|
||||||
oss << classification_report_line("accuracy", 0, 0, accuracy(), total);
|
oss << classification_report_line("accuracy", 0, 0, accuracy(), total);
|
||||||
|
report.push_back(oss.str()); oss.str("");
|
||||||
float precision_avg = 0;
|
float precision_avg = 0;
|
||||||
float recall_avg = 0;
|
float recall_avg = 0;
|
||||||
float precision_wavg = 0;
|
float precision_wavg = 0;
|
||||||
@@ -159,10 +163,11 @@ namespace platform {
|
|||||||
recall_wavg /= total;
|
recall_wavg /= total;
|
||||||
precision_avg /= num_classes;
|
precision_avg /= num_classes;
|
||||||
recall_avg /= num_classes;
|
recall_avg /= num_classes;
|
||||||
oss << classification_report_line("macro avg", precision_avg, recall_avg, f1_macro(), total);
|
report.push_back(classification_report_line("macro avg", precision_avg, recall_avg, f1_macro(), total));
|
||||||
oss << classification_report_line("weighted avg", precision_wavg, recall_wavg, f1_weighted(), total);
|
report.push_back(classification_report_line("weighted avg", precision_wavg, recall_wavg, f1_weighted(), total));
|
||||||
oss << std::endl << Colors::GREEN() << "Confusion Matrix" << std::endl;
|
report.push_back("");
|
||||||
oss << "================" << std::endl << color;
|
report.push_back("Confusion Matrix");
|
||||||
|
report.push_back("================");
|
||||||
auto number = total > 1000 ? 4 : 3;
|
auto number = total > 1000 ? 4 : 3;
|
||||||
for (int i = 0; i < num_classes; i++) {
|
for (int i = 0; i < num_classes; i++) {
|
||||||
oss << std::right << std::setw(label_len) << labels[i] << " ";
|
oss << std::right << std::setw(label_len) << labels[i] << " ";
|
||||||
@@ -171,10 +176,9 @@ namespace platform {
|
|||||||
oss << std::setw(number) << confusion_matrix[i][j].item<int>() << " ";
|
oss << std::setw(number) << confusion_matrix[i][j].item<int>() << " ";
|
||||||
if (i == j) oss << color;
|
if (i == j) oss << color;
|
||||||
}
|
}
|
||||||
oss << std::endl;
|
report.push_back(oss.str()); oss.str("");
|
||||||
}
|
}
|
||||||
oss << Colors::RESET();
|
return report;
|
||||||
return oss.str();
|
|
||||||
}
|
}
|
||||||
json Scores::get_confusion_matrix_json(bool labels_as_keys)
|
json Scores::get_confusion_matrix_json(bool labels_as_keys)
|
||||||
{
|
{
|
||||||
|
@@ -17,7 +17,7 @@ namespace platform {
|
|||||||
float precision(int num_class);
|
float precision(int num_class);
|
||||||
float recall(int num_class);
|
float recall(int num_class);
|
||||||
torch::Tensor get_confusion_matrix() { return confusion_matrix; }
|
torch::Tensor get_confusion_matrix() { return confusion_matrix; }
|
||||||
std::string classification_report(std::string color = "");
|
std::vector<std::string> classification_report(std::string color = "", std::string title = "");
|
||||||
json get_confusion_matrix_json(bool labels_as_keys = false);
|
json get_confusion_matrix_json(bool labels_as_keys = false);
|
||||||
void aggregate(const Scores& a);
|
void aggregate(const Scores& a);
|
||||||
private:
|
private:
|
||||||
@@ -30,7 +30,7 @@ namespace platform {
|
|||||||
int total;
|
int total;
|
||||||
std::vector<std::string> labels;
|
std::vector<std::string> labels;
|
||||||
torch::Tensor confusion_matrix; // Rows ar actual, columns are predicted
|
torch::Tensor confusion_matrix; // Rows ar actual, columns are predicted
|
||||||
int label_len = 12;
|
int label_len = 16;
|
||||||
int dlen = 9;
|
int dlen = 9;
|
||||||
int ndec = 7;
|
int ndec = 7;
|
||||||
};
|
};
|
||||||
|
@@ -3,7 +3,6 @@
|
|||||||
#include "best/BestScore.h"
|
#include "best/BestScore.h"
|
||||||
#include "common/CLocale.h"
|
#include "common/CLocale.h"
|
||||||
#include "ReportConsole.h"
|
#include "ReportConsole.h"
|
||||||
#include "main/Scores.h"
|
|
||||||
|
|
||||||
namespace platform {
|
namespace platform {
|
||||||
std::string ReportConsole::headerLine(const std::string& text, int utf = 0)
|
std::string ReportConsole::headerLine(const std::string& text, int utf = 0)
|
||||||
@@ -135,7 +134,7 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
sbody << std::string(MAXL, '*') << Colors::RESET() << std::endl;
|
sbody << std::string(MAXL, '*') << Colors::RESET() << std::endl;
|
||||||
vbody.push_back(std::string(MAXL, '*') + Colors::RESET() + "\n");
|
vbody.push_back(std::string(MAXL, '*') + Colors::RESET() + "\n");
|
||||||
if (lastResult.find("confusion_matrices") != lastResult.end() && (data["results"].size() == 1 || selectedIndex != -1)) {
|
if (data["results"].size() == 1 || selectedIndex != -1) {
|
||||||
vbody.push_back(showClassificationReport(Colors::BLUE()));
|
vbody.push_back(showClassificationReport(Colors::BLUE()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -169,17 +168,60 @@ namespace platform {
|
|||||||
std::cout << headerLine("*** Best Results File not found. Couldn't compare any result!");
|
std::cout << headerLine("*** Best Results File not found. Couldn't compare any result!");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::string ReportConsole::showClassificationReport(std::string color)
|
Scores ReportConsole::aggregateScore(std::string key)
|
||||||
{
|
{
|
||||||
auto lastResult = data["results"][0];
|
auto lastResult = data["results"][0];
|
||||||
if (data["results"].size() > 1 || lastResult.find("confusion_matrices") == lastResult.end())
|
|
||||||
return "";
|
|
||||||
auto item = data["results"][0];
|
auto item = data["results"][0];
|
||||||
auto scores = Scores(item["confusion_matrices"][0]);
|
auto scores = Scores(item[key][0]);
|
||||||
for (int i = 1; i < item["confusion_matrices"].size(); i++) {
|
for (int i = 1; i < item[key].size(); i++) {
|
||||||
auto score = Scores(item["confusion_matrices"][i]);
|
auto score = Scores(item[key][i]);
|
||||||
scores.aggregate(score);
|
scores.aggregate(score);
|
||||||
}
|
}
|
||||||
return scores.classification_report(color);
|
return scores;
|
||||||
|
}
|
||||||
|
std::string ReportConsole::showClassificationReport(std::string color)
|
||||||
|
{
|
||||||
|
std::stringstream oss;
|
||||||
|
auto result = data["results"][0];
|
||||||
|
if (result.find("confusion_matrices") == result.end())
|
||||||
|
return "";
|
||||||
|
auto scores = aggregateScore("confusion_matrices");
|
||||||
|
auto output_test = scores.classification_report(color, "Test");
|
||||||
|
oss << Colors::BLUE();
|
||||||
|
if (result.find("confusion_matrices_train") == result.end()) {
|
||||||
|
for (auto& line : output_test) {
|
||||||
|
|
||||||
|
oss << line << std::endl;
|
||||||
|
}
|
||||||
|
oss << Colors::RESET();
|
||||||
|
return oss.str();
|
||||||
|
}
|
||||||
|
auto scores_train = aggregateScore("confusion_matrices_train");
|
||||||
|
auto output_train = scores_train.classification_report(color, "Train");
|
||||||
|
int maxLine = (*std::max_element(output_train.begin(), output_train.end(), [](const std::string& a, const std::string& b) { return a.size() < b.size(); })).size();
|
||||||
|
bool second_header = false;
|
||||||
|
int lines_header = 0;
|
||||||
|
std::string color_line;
|
||||||
|
std::string suffix = "";
|
||||||
|
for (int i = 0; i < output_train.size(); i++) {
|
||||||
|
if (i < 2 || second_header) {
|
||||||
|
color_line = Colors::GREEN();
|
||||||
|
} else {
|
||||||
|
color_line = Colors::BLUE();
|
||||||
|
if (lines_header > 1)
|
||||||
|
suffix = std::string(14, ' '); // compensate for the color
|
||||||
|
}
|
||||||
|
oss << color_line << std::left << std::setw(maxLine) << output_train[i]
|
||||||
|
<< suffix << Colors::BLUE() << " | " << color_line << std::left << std::setw(maxLine)
|
||||||
|
<< output_test[i] << std::endl;
|
||||||
|
if (output_train[i] == "" || (second_header && lines_header < 2)) {
|
||||||
|
lines_header++;
|
||||||
|
second_header = true;
|
||||||
|
} else {
|
||||||
|
second_header = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
oss << Colors::RESET();
|
||||||
|
return oss.str();
|
||||||
}
|
}
|
||||||
}
|
}
|
@@ -4,6 +4,8 @@
|
|||||||
#include "common/Colors.h"
|
#include "common/Colors.h"
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include "ReportBase.h"
|
#include "ReportBase.h"
|
||||||
|
#include "main/Scores.h"
|
||||||
|
|
||||||
|
|
||||||
namespace platform {
|
namespace platform {
|
||||||
const int MAXL = 133;
|
const int MAXL = 133;
|
||||||
@@ -24,6 +26,7 @@ namespace platform {
|
|||||||
void do_body();
|
void do_body();
|
||||||
void footer(double totalScore);
|
void footer(double totalScore);
|
||||||
void showSummary() override;
|
void showSummary() override;
|
||||||
|
Scores aggregateScore(std::string key);
|
||||||
std::stringstream sheader;
|
std::stringstream sheader;
|
||||||
std::stringstream sbody;
|
std::stringstream sbody;
|
||||||
std::vector<std::string> vbody;
|
std::vector<std::string> vbody;
|
||||||
|
Reference in New Issue
Block a user