diff --git a/src/main/Experiment.cpp b/src/main/Experiment.cpp index 8657c12..4a5112c 100644 --- a/src/main/Experiment.cpp +++ b/src/main/Experiment.cpp @@ -84,6 +84,7 @@ namespace platform { auto samples = datasets.getNSamples(fileName); auto className = datasets.getClassName(fileName); auto labels = datasets.getLabels(fileName); + int num_classes = states[className].size() == 0 ? labels.size() : states[className].size(); if (!quiet) { std::cout << " " << setw(5) << samples << " " << setw(5) << features.size() << flush; } @@ -153,7 +154,7 @@ namespace platform { // Score train if (!no_train_score) { auto y_predict = clf->predict(X_train); - Scores scores(y_train, y_predict, states[className].size(), labels); + Scores scores(y_train, y_predict, num_classes, labels); accuracy_train_value = scores.accuracy(); confusion_matrices_train.push_back(scores.get_confusion_matrix_json(true)); } @@ -162,7 +163,7 @@ namespace platform { showProgress(nfold + 1, getColor(clf->getStatus()), "c"); test_timer.start(); auto y_predict = clf->predict(X_test); - Scores scores(y_test, y_predict, states[className].size(), labels); + Scores scores(y_test, y_predict, num_classes, labels); auto accuracy_test_value = scores.accuracy(); test_time[item] = test_timer.getDuration(); accuracy_train[item] = accuracy_train_value;