Fix problem with num of classes in pyclassifiers experiments

This commit is contained in:
2024-05-17 14:05:09 +02:00
parent 696c0564a7
commit a3c4bde460

View File

@@ -84,6 +84,7 @@ namespace platform {
auto samples = datasets.getNSamples(fileName); auto samples = datasets.getNSamples(fileName);
auto className = datasets.getClassName(fileName); auto className = datasets.getClassName(fileName);
auto labels = datasets.getLabels(fileName); auto labels = datasets.getLabels(fileName);
int num_classes = states[className].size() == 0 ? labels.size() : states[className].size();
if (!quiet) { if (!quiet) {
std::cout << " " << setw(5) << samples << " " << setw(5) << features.size() << flush; std::cout << " " << setw(5) << samples << " " << setw(5) << features.size() << flush;
} }
@@ -153,7 +154,7 @@ namespace platform {
// Score train // Score train
if (!no_train_score) { if (!no_train_score) {
auto y_predict = clf->predict(X_train); auto y_predict = clf->predict(X_train);
Scores scores(y_train, y_predict, states[className].size(), labels); Scores scores(y_train, y_predict, num_classes, labels);
accuracy_train_value = scores.accuracy(); accuracy_train_value = scores.accuracy();
confusion_matrices_train.push_back(scores.get_confusion_matrix_json(true)); confusion_matrices_train.push_back(scores.get_confusion_matrix_json(true));
} }
@@ -162,7 +163,7 @@ namespace platform {
showProgress(nfold + 1, getColor(clf->getStatus()), "c"); showProgress(nfold + 1, getColor(clf->getStatus()), "c");
test_timer.start(); test_timer.start();
auto y_predict = clf->predict(X_test); auto y_predict = clf->predict(X_test);
Scores scores(y_test, y_predict, states[className].size(), labels); Scores scores(y_test, y_predict, num_classes, labels);
auto accuracy_test_value = scores.accuracy(); auto accuracy_test_value = scores.accuracy();
test_time[item] = test_timer.getDuration(); test_time[item] = test_timer.getDuration();
accuracy_train[item] = accuracy_train_value; accuracy_train[item] = accuracy_train_value;