Add traintest split in gridsearch

This commit is contained in:
2024-06-07 11:05:59 +02:00
parent 5dd3deca1a
commit 361c51d864
8 changed files with 213 additions and 247 deletions

View File

@@ -42,35 +42,37 @@ namespace platform {
sline += "\n";
header.push_back(sline);
int num = 0;
for (const auto& dataset : datasets.getNames()) {
for (const auto& dataset_name : datasets.getNames()) {
std::stringstream line;
line.imbue(loc);
auto color = num % 2 ? Colors::CYAN() : Colors::BLUE();
line << color << setw(3) << right << num++ << " ";
line << setw(maxName) << left << dataset << " ";
datasets.loadDataset(dataset);
auto nSamples = datasets.getNSamples(dataset);
line << setw(maxName) << left << dataset_name << " ";
auto& dataset = datasets.getDataset(dataset_name);
dataset.load();
auto nSamples = dataset.getNSamples();
line << setw(6) << right << nSamples << " ";
auto nFeatures = datasets.getFeatures(dataset).size();
auto nFeatures = dataset.getFeatures().size();
line << setw(5) << right << nFeatures << " ";
auto numericFeatures = datasets.getNumericFeatures(dataset);
auto numericFeatures = dataset.getNumericFeatures();
auto num = std::count(numericFeatures.begin(), numericFeatures.end(), true);
line << setw(5) << right << num << " ";
line << setw(3) << right << datasets.getNClasses(dataset) << " ";
auto nClasses = dataset.getNClasses();
line << setw(3) << right << nClasses << " ";
std::string sep = "";
oss.str("");
for (auto number : datasets.getClassesCounts(dataset)) {
for (auto number : dataset.getClassesCounts()) {
oss << sep << std::setprecision(2) << fixed << (float)number / nSamples * 100.0 << "% (" << number << ")";
sep = " / ";
}
split_lines(maxName, line.str(), oss.str());
// Store data for Excel report
data[dataset] = json::object();
data[dataset]["samples"] = nSamples;
data[dataset]["features"] = datasets.getFeatures(dataset).size();
data[dataset]["numericFeatures"] = num;
data[dataset]["classes"] = datasets.getNClasses(dataset);
data[dataset]["balance"] = oss.str();
data[dataset_name] = json::object();
data[dataset_name]["samples"] = nSamples;
data[dataset_name]["features"] = nFeatures;
data[dataset_name]["numericFeatures"] = num;
data[dataset_name]["classes"] = nClasses;
data[dataset_name]["balance"] = oss.str();
}
}
}

View File

@@ -61,12 +61,13 @@ namespace platform {
}
} else {
if (data["score_name"].get<std::string>() == "accuracy") {
auto dt = Datasets(false, Paths::datasets());
dt.loadDataset(dataset);
auto numClasses = dt.getNClasses(dataset);
auto datasets = Datasets(false, Paths::datasets());
auto& dt = datasets.getDataset(dataset);
dt.load();
auto numClasses = dt.getNClasses();
if (numClasses == 2) {
std::vector<int> distribution = dt.getClassesCounts(dataset);
double nSamples = dt.getNSamples(dataset);
std::vector<int> distribution = dt.getClassesCounts();
double nSamples = dt.getNSamples();
std::vector<int>::iterator maxValue = max_element(distribution.begin(), distribution.end());
double mark = *maxValue / nSamples * (1 + margin);
if (mark > 1) {