Fix problem with num of classes in pyclassifiers experiments
This commit is contained in:
@@ -84,6 +84,7 @@ namespace platform {
|
|||||||
auto samples = datasets.getNSamples(fileName);
|
auto samples = datasets.getNSamples(fileName);
|
||||||
auto className = datasets.getClassName(fileName);
|
auto className = datasets.getClassName(fileName);
|
||||||
auto labels = datasets.getLabels(fileName);
|
auto labels = datasets.getLabels(fileName);
|
||||||
|
int num_classes = states[className].size() == 0 ? labels.size() : states[className].size();
|
||||||
if (!quiet) {
|
if (!quiet) {
|
||||||
std::cout << " " << setw(5) << samples << " " << setw(5) << features.size() << flush;
|
std::cout << " " << setw(5) << samples << " " << setw(5) << features.size() << flush;
|
||||||
}
|
}
|
||||||
@@ -153,7 +154,7 @@ namespace platform {
|
|||||||
// Score train
|
// Score train
|
||||||
if (!no_train_score) {
|
if (!no_train_score) {
|
||||||
auto y_predict = clf->predict(X_train);
|
auto y_predict = clf->predict(X_train);
|
||||||
Scores scores(y_train, y_predict, states[className].size(), labels);
|
Scores scores(y_train, y_predict, num_classes, labels);
|
||||||
accuracy_train_value = scores.accuracy();
|
accuracy_train_value = scores.accuracy();
|
||||||
confusion_matrices_train.push_back(scores.get_confusion_matrix_json(true));
|
confusion_matrices_train.push_back(scores.get_confusion_matrix_json(true));
|
||||||
}
|
}
|
||||||
@@ -162,7 +163,7 @@ namespace platform {
|
|||||||
showProgress(nfold + 1, getColor(clf->getStatus()), "c");
|
showProgress(nfold + 1, getColor(clf->getStatus()), "c");
|
||||||
test_timer.start();
|
test_timer.start();
|
||||||
auto y_predict = clf->predict(X_test);
|
auto y_predict = clf->predict(X_test);
|
||||||
Scores scores(y_test, y_predict, states[className].size(), labels);
|
Scores scores(y_test, y_predict, num_classes, labels);
|
||||||
auto accuracy_test_value = scores.accuracy();
|
auto accuracy_test_value = scores.accuracy();
|
||||||
test_time[item] = test_timer.getDuration();
|
test_time[item] = test_timer.getDuration();
|
||||||
accuracy_train[item] = accuracy_train_value;
|
accuracy_train[item] = accuracy_train_value;
|
||||||
|
Reference in New Issue
Block a user