Add train classification report

This commit is contained in:
2024-05-14 11:45:54 +02:00
parent 99c9c6731f
commit 5c190d7c66
6 changed files with 85 additions and 28 deletions

View File

@@ -102,6 +102,7 @@ namespace platform {
auto edges = torch::zeros({ nResults }, torch::kFloat64);
auto num_states = torch::zeros({ nResults }, torch::kFloat64);
json confusion_matrices = json::array();
json confusion_matrices_train = json::array();
std::vector<std::string> notes;
Timer train_timer, test_timer;
int item = 0;
@@ -150,8 +151,12 @@ namespace platform {
train_time[item] = train_timer.getDuration();
double accuracy_train_value = 0.0;
// Score train
if (!no_train_score)
accuracy_train_value = clf->score(X_train, y_train);
if (!no_train_score) {
auto y_predict = clf->predict(X_train);
Scores scores(y_train, y_predict, states[className].size(), labels);
accuracy_train_value = scores.accuracy();
confusion_matrices_train.push_back(scores.get_confusion_matrix_json(true));
}
// Test model
if (!quiet)
showProgress(nfold + 1, getColor(clf->getStatus()), "c");
@@ -183,6 +188,8 @@ namespace platform {
partial_result.setNodes(torch::mean(nodes).item<double>()).setLeaves(torch::mean(edges).item<double>()).setDepth(torch::mean(num_states).item<double>());
partial_result.setDataset(fileName).setNotes(notes);
partial_result.setConfusionMatrices(confusion_matrices);
if (!no_train_score)
partial_result.setConfusionMatricesTrain(confusion_matrices_train);
addResult(partial_result);
}
}