Add json constructor to Scores
This commit is contained in:
@@ -4,19 +4,57 @@ namespace platform {
|
||||
Scores::Scores(torch::Tensor& y_test, torch::Tensor& y_pred, int num_classes, std::vector<std::string> labels) : num_classes(num_classes), labels(labels)
|
||||
{
|
||||
if (labels.size() == 0) {
|
||||
for (int i = 0; i < num_classes; i++) {
|
||||
this->labels.push_back("Class " + std::to_string(i));
|
||||
}
|
||||
init_default_labels();
|
||||
}
|
||||
total = y_test.size(0);
|
||||
accuracy_value = (y_pred == y_test).sum().item<float>() / total;
|
||||
confusion_matrix = torch::zeros({ num_classes, num_classes }, torch::kInt32);
|
||||
init_confusion_matrix();
|
||||
for (int i = 0; i < total; i++) {
|
||||
int actual = y_test[i].item<int>();
|
||||
int predicted = y_pred[i].item<int>();
|
||||
confusion_matrix[actual][predicted] += 1;
|
||||
}
|
||||
}
|
||||
void Scores::init_confusion_matrix()
|
||||
{
|
||||
confusion_matrix = torch::zeros({ num_classes, num_classes }, torch::kInt32);
|
||||
}
|
||||
void Scores::init_default_labels()
|
||||
{
|
||||
for (int i = 0; i < num_classes; i++) {
|
||||
labels.push_back("Class " + std::to_string(i));
|
||||
}
|
||||
}
|
||||
Scores::Scores(json& confusion_matrix_)
|
||||
{
|
||||
json values;
|
||||
total = 0;
|
||||
num_classes = confusion_matrix_.size();
|
||||
init_confusion_matrix();
|
||||
init_default_labels();
|
||||
int i = 0;
|
||||
for (const auto& item : confusion_matrix_) {
|
||||
if (item.is_array()) {
|
||||
values = item;
|
||||
} else {
|
||||
auto it = item.begin();
|
||||
values = it.value();
|
||||
labels.push_back(it.key());
|
||||
}
|
||||
for (int j = 0; j < num_classes; ++j) {
|
||||
int value_int = values[j].get<int>();
|
||||
confusion_matrix[i][j] = value_int;
|
||||
total += value_int;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
i++;
|
||||
}
|
||||
// Compute accuracy with the confusion matrix
|
||||
for (int i = 0; i < num_classes; i++) {
|
||||
accuracy_value += confusion_matrix[i][i].item<int>();
|
||||
}
|
||||
accuracy_value /= total;
|
||||
}
|
||||
float Scores::accuracy()
|
||||
{
|
||||
return accuracy_value;
|
||||
|
@@ -9,6 +9,7 @@ namespace platform {
|
||||
class Scores {
|
||||
public:
|
||||
Scores(torch::Tensor& y_test, torch::Tensor& y_pred, int num_classes, std::vector<std::string> labels = {});
|
||||
explicit Scores(json& confusion_matrix_);
|
||||
float accuracy();
|
||||
float f1_score(int num_class);
|
||||
float f1_weighted();
|
||||
@@ -20,6 +21,8 @@ namespace platform {
|
||||
json get_confusion_matrix_json(bool labels_as_keys = false);
|
||||
private:
|
||||
std::string classification_report_line(std::string label, float precision, float recall, float f1_score, int support);
|
||||
void init_confusion_matrix();
|
||||
void init_default_labels();
|
||||
int num_classes;
|
||||
float accuracy_value;
|
||||
int total;
|
||||
|
Reference in New Issue
Block a user