diff --git a/lib/argparse b/lib/argparse index 9550b0a..f0759fd 160000 --- a/lib/argparse +++ b/lib/argparse @@ -1 +1 @@ -Subproject commit 9550b0a88c85120a0bf456af935eed2956c73340 +Subproject commit f0759fd982bb4a88785094626ea522cb3a84ec84 diff --git a/lib/catch2 b/lib/catch2 index 53ddf37..4e8d92b 160000 --- a/lib/catch2 +++ b/lib/catch2 @@ -1 +1 @@ -Subproject commit 53ddf37af4488cac7724761858ae3cca9d2d65e7 +Subproject commit 4e8d92bf02f7d1c8006a0e7a5ecabd8e62d98502 diff --git a/lib/folding b/lib/folding index 71d6055..2ac43e3 160000 --- a/lib/folding +++ b/lib/folding @@ -1 +1 @@ -Subproject commit 71d6055be4488cf2e6443123ae8fc4a63ae289dc +Subproject commit 2ac43e32ac1eac0c986702ec526cf5367a565ef0 diff --git a/lib/json b/lib/json index c883fb0..8c391e0 160000 --- a/lib/json +++ b/lib/json @@ -1 +1 @@ -Subproject commit c883fb0f17cbdf75545bddcc551e21a924a31b05 +Subproject commit 8c391e04fe4195d8be862c97f38cfe10e2a3472e diff --git a/lib/libxlsxwriter b/lib/libxlsxwriter index 7548faa..284b61b 160000 --- a/lib/libxlsxwriter +++ b/lib/libxlsxwriter @@ -1 +1 @@ -Subproject commit 7548faa95afdf8ac321136d10eda931683fbf7c6 +Subproject commit 284b61ba0b8930ad93003380defc4a0817b75079 diff --git a/lib/mdlp b/lib/mdlp index 5708dc3..236d1b2 160000 --- a/lib/mdlp +++ b/lib/mdlp @@ -1 +1 @@ -Subproject commit 5708dc3de944fc22d61a2dd071b63aa338e04db3 +Subproject commit 236d1b2f8be185039493fe7fce04a83e02ed72e5 diff --git a/src/main/Scores.cpp b/src/main/Scores.cpp index a735605..959280b 100644 --- a/src/main/Scores.cpp +++ b/src/main/Scores.cpp @@ -15,6 +15,34 @@ namespace platform { confusion_matrix[actual][predicted] += 1; } } + Scores::Scores(json& confusion_matrix_) + { + json values; + total = 0; + num_classes = confusion_matrix_.size(); + init_confusion_matrix(); + int i = 0; + for (const auto& item : confusion_matrix_.items()) { + values = item.value(); + json key = item.key(); + if (key.is_number_integer()) { + labels.push_back("Class " + std::to_string(key.get())); + } else { + labels.push_back(key.get()); + } + for (int j = 0; j < num_classes; ++j) { + int value_int = values[j].get(); + confusion_matrix[i][j] = value_int; + total += value_int; + } + i++; + } + // Compute accuracy with the confusion matrix + for (int i = 0; i < num_classes; i++) { + accuracy_value += confusion_matrix[i][i].item(); + } + accuracy_value /= total; + } void Scores::init_confusion_matrix() { confusion_matrix = torch::zeros({ num_classes, num_classes }, torch::kInt32); @@ -34,35 +62,6 @@ namespace platform { accuracy_value += a.accuracy_value; accuracy_value /= 2; } - Scores::Scores(json& confusion_matrix_) - { - json values; - total = 0; - num_classes = confusion_matrix_.size(); - init_confusion_matrix(); - init_default_labels(); - int i = 0; - for (const auto& item : confusion_matrix_) { - if (item.is_array()) { - values = item; - } else { - auto it = item.begin(); - values = it.value(); - labels.push_back(it.key()); - } - for (int j = 0; j < num_classes; ++j) { - int value_int = values[j].get(); - confusion_matrix[i][j] = value_int; - total += value_int; - } - i++; - } - // Compute accuracy with the confusion matrix - for (int i = 0; i < num_classes; i++) { - accuracy_value += confusion_matrix[i][i].item(); - } - accuracy_value /= total; - } float Scores::accuracy() { return accuracy_value; @@ -125,6 +124,9 @@ namespace platform { std::string Scores::classification_report() { std::stringstream oss; + for (int i = 0; i < num_classes; i++) { + label_len = std::max(label_len, (int)labels[i].size()); + } oss << "Classification Report" << std::endl; oss << "=====================" << std::endl; oss << std::string(label_len, ' ') << " precision recall f1-score support" << std::endl;