Refactor aggregate score to a constructor
This commit is contained in:
@@ -40,6 +40,15 @@ namespace platform {
|
||||
}
|
||||
compute_accuracy_value();
|
||||
}
|
||||
static Score Scores::create_aggregate(json& data, std::string key)
|
||||
{
|
||||
auto scores = Scores(result[key][0]);
|
||||
for (int i = 1; i < result[key].size(); i++) {
|
||||
auto score = Scores(result[key][i]);
|
||||
scores.aggregate(score);
|
||||
}
|
||||
return scores;
|
||||
}
|
||||
void Scores::compute_accuracy_value()
|
||||
{
|
||||
accuracy_value = 0;
|
||||
|
@@ -10,6 +10,7 @@ namespace platform {
|
||||
public:
|
||||
Scores(torch::Tensor& y_test, torch::Tensor& y_pred, int num_classes, std::vector<std::string> labels = {});
|
||||
explicit Scores(json& confusion_matrix_);
|
||||
static Score create_aggregate(json& data, std::string key);
|
||||
float accuracy();
|
||||
float f1_score(int num_class);
|
||||
float f1_weighted();
|
||||
|
@@ -186,13 +186,13 @@ namespace platform {
|
||||
int lines_header = 0;
|
||||
std::string color_line;
|
||||
std::string suffix = "";
|
||||
auto scores = aggregateScore(result, "confusion_matrices");
|
||||
auto scores = Scores::create_aggregate(result, "confusion_matrices");
|
||||
auto output_test = scores.classification_report(color, "Test");
|
||||
int maxLine = (*std::max_element(output_test.begin(), output_test.end(), [](const std::string& a, const std::string& b) { return a.size() < b.size(); })).size();
|
||||
bool train_data = result.find("confusion_matrices_train") != result.end();
|
||||
std::vector<std::string> output_train;
|
||||
if (train_data) {
|
||||
auto scores_train = aggregateScore(result, "confusion_matrices_train");
|
||||
auto scores_train = Scores::create_aggregate(result, "confusion_matrices_train");
|
||||
output_train = scores_train.classification_report(color, "Train");
|
||||
}
|
||||
oss << Colors::BLUE();
|
||||
|
Reference in New Issue
Block a user