Git add Confusion Matrix to console report
This commit is contained in:
@@ -2,6 +2,8 @@
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <map>
|
||||
#include <cctype> // std::isdigit
|
||||
#include <algorithm> // std::all_of
|
||||
#include <iostream>
|
||||
|
||||
ArffFiles::ArffFiles() = default;
|
||||
@@ -162,7 +164,11 @@ std::vector<int> ArffFiles::factorize(const std::vector<std::string>& labels_t)
|
||||
for (const std::string& label : labels_t) {
|
||||
if (labelMap.find(label) == labelMap.end()) {
|
||||
labelMap[label] = i++;
|
||||
labels.push_back(label);
|
||||
bool allDigits = std::all_of(label.begin(), label.end(), isdigit);
|
||||
if (allDigits)
|
||||
labels.push_back("Class " + label);
|
||||
else
|
||||
labels.push_back(label);
|
||||
}
|
||||
yy.push_back(labelMap[label]);
|
||||
}
|
||||
|
@@ -160,6 +160,16 @@ namespace platform {
|
||||
recall_avg /= num_classes;
|
||||
oss << classification_report_line("macro avg", precision_avg, recall_avg, f1_macro(), total);
|
||||
oss << classification_report_line("weighted avg", precision_wavg, recall_wavg, f1_weighted(), total);
|
||||
oss << std::endl << "Confusion Matrix" << std::endl;
|
||||
oss << "================" << std::endl;
|
||||
auto number = total > 1000 ? 4 : 3;
|
||||
for (int i = 0; i < num_classes; i++) {
|
||||
oss << std::right << std::setw(label_len) << labels[i] << " ";
|
||||
for (int j = 0; j < num_classes; j++) {
|
||||
oss << std::setw(number) << confusion_matrix[i][j].item<int>() << " ";
|
||||
}
|
||||
oss << std::endl;
|
||||
}
|
||||
return oss.str();
|
||||
}
|
||||
json Scores::get_confusion_matrix_json(bool labels_as_keys)
|
||||
|
Reference in New Issue
Block a user