Update libraries
This commit is contained in:
Submodule lib/argparse updated: 9550b0a88c...f0759fd982
Submodule lib/catch2 updated: 53ddf37af4...4e8d92bf02
Submodule lib/folding updated: 71d6055be4...2ac43e32ac
2
lib/json
2
lib/json
Submodule lib/json updated: c883fb0f17...8c391e04fe
Submodule lib/libxlsxwriter updated: 7548faa95a...284b61ba0b
2
lib/mdlp
2
lib/mdlp
Submodule lib/mdlp updated: 5708dc3de9...236d1b2f8b
@@ -15,6 +15,34 @@ namespace platform {
|
|||||||
confusion_matrix[actual][predicted] += 1;
|
confusion_matrix[actual][predicted] += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Scores::Scores(json& confusion_matrix_)
|
||||||
|
{
|
||||||
|
json values;
|
||||||
|
total = 0;
|
||||||
|
num_classes = confusion_matrix_.size();
|
||||||
|
init_confusion_matrix();
|
||||||
|
int i = 0;
|
||||||
|
for (const auto& item : confusion_matrix_.items()) {
|
||||||
|
values = item.value();
|
||||||
|
json key = item.key();
|
||||||
|
if (key.is_number_integer()) {
|
||||||
|
labels.push_back("Class " + std::to_string(key.get<int>()));
|
||||||
|
} else {
|
||||||
|
labels.push_back(key.get<std::string>());
|
||||||
|
}
|
||||||
|
for (int j = 0; j < num_classes; ++j) {
|
||||||
|
int value_int = values[j].get<int>();
|
||||||
|
confusion_matrix[i][j] = value_int;
|
||||||
|
total += value_int;
|
||||||
|
}
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
// Compute accuracy with the confusion matrix
|
||||||
|
for (int i = 0; i < num_classes; i++) {
|
||||||
|
accuracy_value += confusion_matrix[i][i].item<int>();
|
||||||
|
}
|
||||||
|
accuracy_value /= total;
|
||||||
|
}
|
||||||
void Scores::init_confusion_matrix()
|
void Scores::init_confusion_matrix()
|
||||||
{
|
{
|
||||||
confusion_matrix = torch::zeros({ num_classes, num_classes }, torch::kInt32);
|
confusion_matrix = torch::zeros({ num_classes, num_classes }, torch::kInt32);
|
||||||
@@ -34,35 +62,6 @@ namespace platform {
|
|||||||
accuracy_value += a.accuracy_value;
|
accuracy_value += a.accuracy_value;
|
||||||
accuracy_value /= 2;
|
accuracy_value /= 2;
|
||||||
}
|
}
|
||||||
Scores::Scores(json& confusion_matrix_)
|
|
||||||
{
|
|
||||||
json values;
|
|
||||||
total = 0;
|
|
||||||
num_classes = confusion_matrix_.size();
|
|
||||||
init_confusion_matrix();
|
|
||||||
init_default_labels();
|
|
||||||
int i = 0;
|
|
||||||
for (const auto& item : confusion_matrix_) {
|
|
||||||
if (item.is_array()) {
|
|
||||||
values = item;
|
|
||||||
} else {
|
|
||||||
auto it = item.begin();
|
|
||||||
values = it.value();
|
|
||||||
labels.push_back(it.key());
|
|
||||||
}
|
|
||||||
for (int j = 0; j < num_classes; ++j) {
|
|
||||||
int value_int = values[j].get<int>();
|
|
||||||
confusion_matrix[i][j] = value_int;
|
|
||||||
total += value_int;
|
|
||||||
}
|
|
||||||
i++;
|
|
||||||
}
|
|
||||||
// Compute accuracy with the confusion matrix
|
|
||||||
for (int i = 0; i < num_classes; i++) {
|
|
||||||
accuracy_value += confusion_matrix[i][i].item<int>();
|
|
||||||
}
|
|
||||||
accuracy_value /= total;
|
|
||||||
}
|
|
||||||
float Scores::accuracy()
|
float Scores::accuracy()
|
||||||
{
|
{
|
||||||
return accuracy_value;
|
return accuracy_value;
|
||||||
@@ -125,6 +124,9 @@ namespace platform {
|
|||||||
std::string Scores::classification_report()
|
std::string Scores::classification_report()
|
||||||
{
|
{
|
||||||
std::stringstream oss;
|
std::stringstream oss;
|
||||||
|
for (int i = 0; i < num_classes; i++) {
|
||||||
|
label_len = std::max(label_len, (int)labels[i].size());
|
||||||
|
}
|
||||||
oss << "Classification Report" << std::endl;
|
oss << "Classification Report" << std::endl;
|
||||||
oss << "=====================" << std::endl;
|
oss << "=====================" << std::endl;
|
||||||
oss << std::string(label_len, ' ') << " precision recall f1-score support" << std::endl;
|
oss << std::string(label_len, ' ') << " precision recall f1-score support" << std::endl;
|
||||||
|
Reference in New Issue
Block a user