Update libraries

This commit is contained in:
2024-05-12 12:26:49 +02:00
parent 69b9609154
commit f88b223c46
7 changed files with 37 additions and 35 deletions

View File

@@ -15,6 +15,34 @@ namespace platform {
confusion_matrix[actual][predicted] += 1;
}
}
Scores::Scores(json& confusion_matrix_)
{
json values;
total = 0;
num_classes = confusion_matrix_.size();
init_confusion_matrix();
int i = 0;
for (const auto& item : confusion_matrix_.items()) {
values = item.value();
json key = item.key();
if (key.is_number_integer()) {
labels.push_back("Class " + std::to_string(key.get<int>()));
} else {
labels.push_back(key.get<std::string>());
}
for (int j = 0; j < num_classes; ++j) {
int value_int = values[j].get<int>();
confusion_matrix[i][j] = value_int;
total += value_int;
}
i++;
}
// Compute accuracy with the confusion matrix
for (int i = 0; i < num_classes; i++) {
accuracy_value += confusion_matrix[i][i].item<int>();
}
accuracy_value /= total;
}
void Scores::init_confusion_matrix()
{
confusion_matrix = torch::zeros({ num_classes, num_classes }, torch::kInt32);
@@ -34,35 +62,6 @@ namespace platform {
accuracy_value += a.accuracy_value;
accuracy_value /= 2;
}
Scores::Scores(json& confusion_matrix_)
{
json values;
total = 0;
num_classes = confusion_matrix_.size();
init_confusion_matrix();
init_default_labels();
int i = 0;
for (const auto& item : confusion_matrix_) {
if (item.is_array()) {
values = item;
} else {
auto it = item.begin();
values = it.value();
labels.push_back(it.key());
}
for (int j = 0; j < num_classes; ++j) {
int value_int = values[j].get<int>();
confusion_matrix[i][j] = value_int;
total += value_int;
}
i++;
}
// Compute accuracy with the confusion matrix
for (int i = 0; i < num_classes; i++) {
accuracy_value += confusion_matrix[i][i].item<int>();
}
accuracy_value /= total;
}
float Scores::accuracy()
{
return accuracy_value;
@@ -125,6 +124,9 @@ namespace platform {
std::string Scores::classification_report()
{
std::stringstream oss;
for (int i = 0; i < num_classes; i++) {
label_len = std::max(label_len, (int)labels[i].size());
}
oss << "Classification Report" << std::endl;
oss << "=====================" << std::endl;
oss << std::string(label_len, ' ') << " precision recall f1-score support" << std::endl;