Add train classification report
This commit is contained in:
@@ -102,6 +102,7 @@ namespace platform {
|
||||
auto edges = torch::zeros({ nResults }, torch::kFloat64);
|
||||
auto num_states = torch::zeros({ nResults }, torch::kFloat64);
|
||||
json confusion_matrices = json::array();
|
||||
json confusion_matrices_train = json::array();
|
||||
std::vector<std::string> notes;
|
||||
Timer train_timer, test_timer;
|
||||
int item = 0;
|
||||
@@ -150,8 +151,12 @@ namespace platform {
|
||||
train_time[item] = train_timer.getDuration();
|
||||
double accuracy_train_value = 0.0;
|
||||
// Score train
|
||||
if (!no_train_score)
|
||||
accuracy_train_value = clf->score(X_train, y_train);
|
||||
if (!no_train_score) {
|
||||
auto y_predict = clf->predict(X_train);
|
||||
Scores scores(y_train, y_predict, states[className].size(), labels);
|
||||
accuracy_train_value = scores.accuracy();
|
||||
confusion_matrices_train.push_back(scores.get_confusion_matrix_json(true));
|
||||
}
|
||||
// Test model
|
||||
if (!quiet)
|
||||
showProgress(nfold + 1, getColor(clf->getStatus()), "c");
|
||||
@@ -183,6 +188,8 @@ namespace platform {
|
||||
partial_result.setNodes(torch::mean(nodes).item<double>()).setLeaves(torch::mean(edges).item<double>()).setDepth(torch::mean(num_states).item<double>());
|
||||
partial_result.setDataset(fileName).setNotes(notes);
|
||||
partial_result.setConfusionMatrices(confusion_matrices);
|
||||
if (!no_train_score)
|
||||
partial_result.setConfusionMatricesTrain(confusion_matrices_train);
|
||||
addResult(partial_result);
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user