Refactor aggregate score to a constructor
This commit is contained in:
@@ -40,6 +40,15 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
compute_accuracy_value();
|
compute_accuracy_value();
|
||||||
}
|
}
|
||||||
|
static Score Scores::create_aggregate(json& data, std::string key)
|
||||||
|
{
|
||||||
|
auto scores = Scores(result[key][0]);
|
||||||
|
for (int i = 1; i < result[key].size(); i++) {
|
||||||
|
auto score = Scores(result[key][i]);
|
||||||
|
scores.aggregate(score);
|
||||||
|
}
|
||||||
|
return scores;
|
||||||
|
}
|
||||||
void Scores::compute_accuracy_value()
|
void Scores::compute_accuracy_value()
|
||||||
{
|
{
|
||||||
accuracy_value = 0;
|
accuracy_value = 0;
|
||||||
|
@@ -10,6 +10,7 @@ namespace platform {
|
|||||||
public:
|
public:
|
||||||
Scores(torch::Tensor& y_test, torch::Tensor& y_pred, int num_classes, std::vector<std::string> labels = {});
|
Scores(torch::Tensor& y_test, torch::Tensor& y_pred, int num_classes, std::vector<std::string> labels = {});
|
||||||
explicit Scores(json& confusion_matrix_);
|
explicit Scores(json& confusion_matrix_);
|
||||||
|
static Score create_aggregate(json& data, std::string key);
|
||||||
float accuracy();
|
float accuracy();
|
||||||
float f1_score(int num_class);
|
float f1_score(int num_class);
|
||||||
float f1_weighted();
|
float f1_weighted();
|
||||||
|
@@ -186,13 +186,13 @@ namespace platform {
|
|||||||
int lines_header = 0;
|
int lines_header = 0;
|
||||||
std::string color_line;
|
std::string color_line;
|
||||||
std::string suffix = "";
|
std::string suffix = "";
|
||||||
auto scores = aggregateScore(result, "confusion_matrices");
|
auto scores = Scores::create_aggregate(result, "confusion_matrices");
|
||||||
auto output_test = scores.classification_report(color, "Test");
|
auto output_test = scores.classification_report(color, "Test");
|
||||||
int maxLine = (*std::max_element(output_test.begin(), output_test.end(), [](const std::string& a, const std::string& b) { return a.size() < b.size(); })).size();
|
int maxLine = (*std::max_element(output_test.begin(), output_test.end(), [](const std::string& a, const std::string& b) { return a.size() < b.size(); })).size();
|
||||||
bool train_data = result.find("confusion_matrices_train") != result.end();
|
bool train_data = result.find("confusion_matrices_train") != result.end();
|
||||||
std::vector<std::string> output_train;
|
std::vector<std::string> output_train;
|
||||||
if (train_data) {
|
if (train_data) {
|
||||||
auto scores_train = aggregateScore(result, "confusion_matrices_train");
|
auto scores_train = Scores::create_aggregate(result, "confusion_matrices_train");
|
||||||
output_train = scores_train.classification_report(color, "Train");
|
output_train = scores_train.classification_report(color, "Train");
|
||||||
}
|
}
|
||||||
oss << Colors::BLUE();
|
oss << Colors::BLUE();
|
||||||
|
Reference in New Issue
Block a user