Update libraries
This commit is contained in:
Submodule lib/argparse updated: 9550b0a88c...f0759fd982
Submodule lib/catch2 updated: 53ddf37af4...4e8d92bf02
Submodule lib/folding updated: 71d6055be4...2ac43e32ac
2
lib/json
2
lib/json
Submodule lib/json updated: c883fb0f17...8c391e04fe
Submodule lib/libxlsxwriter updated: 7548faa95a...284b61ba0b
2
lib/mdlp
2
lib/mdlp
Submodule lib/mdlp updated: 5708dc3de9...236d1b2f8b
@@ -15,6 +15,34 @@ namespace platform {
|
||||
confusion_matrix[actual][predicted] += 1;
|
||||
}
|
||||
}
|
||||
Scores::Scores(json& confusion_matrix_)
|
||||
{
|
||||
json values;
|
||||
total = 0;
|
||||
num_classes = confusion_matrix_.size();
|
||||
init_confusion_matrix();
|
||||
int i = 0;
|
||||
for (const auto& item : confusion_matrix_.items()) {
|
||||
values = item.value();
|
||||
json key = item.key();
|
||||
if (key.is_number_integer()) {
|
||||
labels.push_back("Class " + std::to_string(key.get<int>()));
|
||||
} else {
|
||||
labels.push_back(key.get<std::string>());
|
||||
}
|
||||
for (int j = 0; j < num_classes; ++j) {
|
||||
int value_int = values[j].get<int>();
|
||||
confusion_matrix[i][j] = value_int;
|
||||
total += value_int;
|
||||
}
|
||||
i++;
|
||||
}
|
||||
// Compute accuracy with the confusion matrix
|
||||
for (int i = 0; i < num_classes; i++) {
|
||||
accuracy_value += confusion_matrix[i][i].item<int>();
|
||||
}
|
||||
accuracy_value /= total;
|
||||
}
|
||||
void Scores::init_confusion_matrix()
|
||||
{
|
||||
confusion_matrix = torch::zeros({ num_classes, num_classes }, torch::kInt32);
|
||||
@@ -34,35 +62,6 @@ namespace platform {
|
||||
accuracy_value += a.accuracy_value;
|
||||
accuracy_value /= 2;
|
||||
}
|
||||
Scores::Scores(json& confusion_matrix_)
|
||||
{
|
||||
json values;
|
||||
total = 0;
|
||||
num_classes = confusion_matrix_.size();
|
||||
init_confusion_matrix();
|
||||
init_default_labels();
|
||||
int i = 0;
|
||||
for (const auto& item : confusion_matrix_) {
|
||||
if (item.is_array()) {
|
||||
values = item;
|
||||
} else {
|
||||
auto it = item.begin();
|
||||
values = it.value();
|
||||
labels.push_back(it.key());
|
||||
}
|
||||
for (int j = 0; j < num_classes; ++j) {
|
||||
int value_int = values[j].get<int>();
|
||||
confusion_matrix[i][j] = value_int;
|
||||
total += value_int;
|
||||
}
|
||||
i++;
|
||||
}
|
||||
// Compute accuracy with the confusion matrix
|
||||
for (int i = 0; i < num_classes; i++) {
|
||||
accuracy_value += confusion_matrix[i][i].item<int>();
|
||||
}
|
||||
accuracy_value /= total;
|
||||
}
|
||||
float Scores::accuracy()
|
||||
{
|
||||
return accuracy_value;
|
||||
@@ -125,6 +124,9 @@ namespace platform {
|
||||
std::string Scores::classification_report()
|
||||
{
|
||||
std::stringstream oss;
|
||||
for (int i = 0; i < num_classes; i++) {
|
||||
label_len = std::max(label_len, (int)labels[i].size());
|
||||
}
|
||||
oss << "Classification Report" << std::endl;
|
||||
oss << "=====================" << std::endl;
|
||||
oss << std::string(label_len, ' ') << " precision recall f1-score support" << std::endl;
|
||||
|
Reference in New Issue
Block a user