From 4545f76667f71d34aeec19b01ddb32d346806ea0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Mon, 2 Sep 2024 18:14:53 +0200 Subject: [PATCH] Begin adding TeX output to b_best -m any command --- src/best/BestResults.cpp | 109 +++++++++++++++++++++++++++------- src/best/BestResults.h | 6 +- src/best/BestResultsExcel.cpp | 2 +- src/commands/b_best.cpp | 6 +- src/common/Paths.h | 9 +++ 5 files changed, 106 insertions(+), 26 deletions(-) diff --git a/src/best/BestResults.cpp b/src/best/BestResults.cpp index c0b8622..a90adde 100644 --- a/src/best/BestResults.cpp +++ b/src/best/BestResults.cpp @@ -52,7 +52,7 @@ namespace platform { } } if (update) { - bests[datasetName] = { item.at("score").get(), item.at("hyperparameters"), file }; + bests[datasetName] = { item.at("score").get(), item.at("hyperparameters"), file, item.at("score_std").get() }; } } } @@ -210,15 +210,56 @@ namespace platform { table["dateTable"] = ftime_to_string(maxDate); return table; } - void BestResults::printTableResults(std::vector models, json table) + double compute_std(std::vector values, double mean) + { + // Population standard deviation of the values around mean (0.0 when values is empty) + double sum = 0.0; + for (const auto& value : values) { + sum += std::pow(value - mean, 2); + } + double variance = values.empty() ? 0.0 : sum / values.size(); + return std::sqrt(variance); + } + void BestResults::printTableResults(std::vector models, json table, bool tex) { std::stringstream oss; oss << Colors::GREEN() << "Best results for " << score << " as of " << table.at("dateTable").get() << std::endl; + std::FILE* output_tex = nullptr; std::cout << oss.str(); std::cout << std::string(oss.str().size() - 8, '-') << std::endl; std::cout << Colors::GREEN() << " # " << std::setw(maxDatasetName + 1) << std::left << std::string("Dataset"); + if (tex) { + auto file_name = Paths::tex_output(); + output_tex = fopen(file_name.c_str(), "w"); + if (output_tex == nullptr) { + std::cerr << "Error opening file "<< file_name << std::endl; + exit(1); + } + 
fprintf(output_tex, "%% This file has been generated by the platform program\n"); + fprintf(output_tex, "%% Date: %s\n", table.at("dateTable").get().c_str()); + fprintf(output_tex, "%%\n"); + fprintf(output_tex, "%% Table of results\n"); + fprintf(output_tex, "%%\n"); + fprintf(output_tex, "\\begin{table}[htbp] \n"); + fprintf(output_tex, "\\centering \n"); + fprintf(output_tex, "\\tiny \n"); + fprintf(output_tex, "\\renewcommand{\\arraystretch }{1.2} \n"); + fprintf(output_tex, "\\renewcommand{\\tabcolsep }{0.07cm} \n"); + fprintf(output_tex, "\\caption{Accuracy results(mean ± std) for all the algorithms and datasets} \n"); + fprintf(output_tex, "\\label{tab:results_accuracy}\n"); + fprintf(output_tex, "\\begin{tabular} {{r%s}}\n", std::string(models.size(), 'c').c_str()); + fprintf(output_tex, "\\hline \n"); + fprintf(output_tex, "Id"); + } for (const auto& model : models) { std::cout << std::setw(maxModelName) << std::left << model << " "; + if (tex) { + fprintf(output_tex, "& %s ", model.c_str()); + } + } + if (tex) { + fprintf(output_tex, "\\\\ \n"); + fprintf(output_tex, "\\hline \n"); } std::cout << std::endl; std::cout << "=== " << std::string(maxDatasetName, '=') << " "; @@ -227,12 +268,10 @@ namespace platform { } std::cout << std::endl; auto i = 0; - std::map totals; + std::map> totals; int nDatasets = table.begin().value().size(); - for (const auto& model : models) { - totals[model] = 0.0; - } auto datasets = getDatasets(table.begin().value()); + for (auto const& dataset_ : datasets) { auto color = (i % 2) ? 
Colors::BLUE() : Colors::CYAN(); std::cout << color << std::setw(3) << std::fixed << std::right << i++ << " "; @@ -251,6 +290,9 @@ namespace platform { maxValue = value; } } + if (tex) { + fprintf(output_tex, "%d ", i); + } // Print the row with red colors on max values for (const auto& model : models) { std::string efectiveColor = color; @@ -267,30 +309,53 @@ namespace platform { if (value == -1) { std::cout << Colors::YELLOW() << std::setw(maxModelName) << std::right << "N/A" << " "; } else { - totals[model] += value; + totals[model].push_back(value); std::cout << efectiveColor << std::setw(maxModelName) << std::setprecision(maxModelName - 2) << std::fixed << value << " "; } + if (tex) { + auto std_value = table[model].at(dataset_).at(3).get(); + const char* bold = value == maxValue ? "\\bfseries" : ""; + fprintf(output_tex, "& %s %0.4f±%0.3f", bold, value, std_value); + } } std::cout << std::endl; + if (tex) { + fprintf(output_tex, "\\\\\n"); + } } std::cout << Colors::GREEN() << "=== " << std::string(maxDatasetName, '=') << " "; for (const auto& model : models) { std::cout << std::string(maxModelName, '=') << " "; } std::cout << std::endl; - std::cout << Colors::GREEN() << " Totals" << std::string(maxDatasetName - 6, '.') << " "; + std::cout << Colors::GREEN() << " Average" << std::string(maxDatasetName - 7, '.') << " "; double max_value = 0.0; + std::string best_model = ""; for (const auto& total : totals) { - if (total.second > max_value) { - max_value = total.second; + auto actual = std::reduce(total.second.begin(), total.second.end()); + if (actual > max_value) { + max_value = actual; + best_model = total.first; } } + if (tex) { + fprintf(output_tex, "\\hline \n"); + fprintf(output_tex, "Average "); + } for (const auto& model : models) { - std::string efectiveColor = Colors::GREEN(); - if (totals[model] == max_value) { - efectiveColor = Colors::RED(); + std::string efectiveColor = model == best_model ? 
Colors::RED() : Colors::GREEN(); + double value = std::reduce(totals[model].begin(), totals[model].end()) / nDatasets; + double std_value = compute_std(totals[model], value); + std::cout << efectiveColor << std::right << std::setw(maxModelName) << std::setprecision(maxModelName - 4) << std::fixed << value << " "; + if (tex) { + const char* bold = model == best_model ? "\\bfseries" : ""; + fprintf(output_tex, "& %s %0.4f±%0.3f", bold, value, std_value); } - std::cout << efectiveColor << std::right << std::setw(maxModelName) << std::setprecision(maxModelName - 4) << std::fixed << totals[model] << " "; + } + if (tex) { + // Footer for TeX: terminate the Average row before the closing rule, then close the environments + fprintf(output_tex, "\\\\ \n\\hline \n\\end{tabular}\n\\end{table}\n"); + fclose(output_tex); } std::cout << std::endl; } @@ -304,17 +369,17 @@ namespace platform { std::vector datasets = getDatasets(table.begin().value()); BestResultsExcel excel_report(score, datasets); excel_report.reportSingle(model, path + Paths::bestResultsFile(score, model)); - messageExcelFile(excel_report.getFileName()); + messageOutputFile("Excel", excel_report.getFileName()); } } - void BestResults::reportAll(bool excel) + void BestResults::reportAll(bool excel, bool tex) { auto models = getModels(); // Build the table of results json table = buildTableResults(models); std::vector datasets = getDatasets(table.begin().value()); // Print the table of results - printTableResults(models, table); + printTableResults(models, table, tex); // Compute the Friedman test std::map> ranksModels; if (friedman) { @@ -323,6 +388,9 @@ namespace platform { stats.postHocHolmTest(result); ranksModels = stats.getRanks(); } + if (tex) { + messageOutputFile("TeX", Paths::tex_output()); + } if (excel) { BestResultsExcel excel(score, datasets); excel.reportAll(models, table, ranksModels, friedman, significance); @@ -345,11 +413,12 @@ namespace platform { model = models.at(idx); excel.reportSingle(model, path + Paths::bestResultsFile(score, model)); } - 
messageExcelFile(excel.getFileName()); + messageOutputFile("Excel", excel.getFileName()); } } - void BestResults::messageExcelFile(const std::string& fileName) + void BestResults::messageOutputFile(const std::string& title, const std::string& fileName) { - std::cout << Colors::YELLOW() << "** Excel file generated: " << fileName << Colors::RESET() << std::endl; + std::cout << Colors::YELLOW() << "** " << std::setw(5) << std::left << title + << " file generated: " << fileName << Colors::RESET() << std::endl; } } \ No newline at end of file diff --git a/src/best/BestResults.h b/src/best/BestResults.h index 2098dea..6ecee03 100644 --- a/src/best/BestResults.h +++ b/src/best/BestResults.h @@ -13,15 +13,15 @@ namespace platform { } std::string build(); void reportSingle(bool excel); - void reportAll(bool excel); + void reportAll(bool excel, bool tex); void buildAll(); private: std::vector getModels(); std::vector getDatasets(json table); std::vector loadResultFiles(); - void messageExcelFile(const std::string& fileName); + void messageOutputFile(const std::string& title, const std::string& fileName); json buildTableResults(std::vector models); - void printTableResults(std::vector models, json table); + void printTableResults(std::vector models, json table, bool tex); json loadFile(const std::string& fileName); void listFile(); std::string path; diff --git a/src/best/BestResultsExcel.cpp b/src/best/BestResultsExcel.cpp index 1bc36bd..172be07 100644 --- a/src/best/BestResultsExcel.cpp +++ b/src/best/BestResultsExcel.cpp @@ -32,7 +32,7 @@ namespace platform { } BestResultsExcel::BestResultsExcel(const std::string& score, const std::vector& datasets) : score(score), datasets(datasets) { - file_name = "BestResults.xlsx"; + file_name = Paths::bestResultsExcel(); workbook = workbook_new(getFileName().c_str()); setProperties("Best Results"); int maxDatasetName = (*max_element(datasets.begin(), datasets.end(), [](const std::string& a, const std::string& b) { return a.size() < 
b.size(); })).size(); diff --git a/src/commands/b_best.cpp b/src/commands/b_best.cpp index bacee15..4e1adc5 100644 --- a/src/commands/b_best.cpp +++ b/src/commands/b_best.cpp @@ -16,6 +16,7 @@ void manageArguments(argparse::ArgumentParser& program) program.add_argument("-s", "--score").default_value("accuracy").help("Filter results of the score name supplied"); program.add_argument("--friedman").help("Friedman test").default_value(false).implicit_value(true); program.add_argument("--excel").help("Output to excel").default_value(false).implicit_value(true); + program.add_argument("--tex").help("Output result table to TeX file").default_value(false).implicit_value(true); program.add_argument("--level").help("significance level").default_value(0.05).scan<'g', double>().action([](const std::string& value) { try { auto k = std::stod(value); @@ -37,7 +38,7 @@ int main(int argc, char** argv) argparse::ArgumentParser program("b_best", { platform_project_version.begin(), platform_project_version.end() }); manageArguments(program); std::string model, dataset, score; - bool build, report, friedman, excel; + bool build, report, friedman, excel, tex; double level; try { program.parse_args(argc, argv); @@ -46,6 +47,7 @@ int main(int argc, char** argv) score = program.get("score"); friedman = program.get("friedman"); excel = program.get("excel"); + tex = program.get("tex"); level = program.get("level"); if (model == "" || score == "") { throw std::runtime_error("Model and score name must be supplied"); @@ -65,7 +67,7 @@ int main(int argc, char** argv) auto results = platform::BestResults(platform::Paths::results(), score, model, dataset, friedman, level); if (model == "any") { results.buildAll(); - results.reportAll(excel); + results.reportAll(excel, tex); } else { std::string fileName = results.build(); std::cout << Colors::GREEN() << fileName << " created!" 
<< Colors::RESET() << std::endl; diff --git a/src/common/Paths.h b/src/common/Paths.h index f8b8f8e..9e8fcc6 100644 --- a/src/common/Paths.h +++ b/src/common/Paths.h @@ -11,6 +11,7 @@ namespace platform { static std::string excel() { return "excel/"; } static std::string grid() { return "grid/"; } static std::string graphs() { return "graphs/"; } + static std::string tex() { return "tex/"; } static std::string datasets() { auto env = platform::DotEnv(); @@ -36,6 +37,10 @@ namespace platform { { return "best_results_" + score + "_" + model + ".json"; } + static std::string bestResultsExcel() + { + return "BestResults.xlsx"; /* same file name BestResultsExcel's constructor hard-coded before this helper */ + } static std::string excelResults() { return "some_results.xlsx"; } static std::string grid_input(const std::string& model) { @@ -45,6 +50,10 @@ namespace platform { { return grid() + "grid_" + model + "_output.json"; } + static std::string tex_output() + { + return "results.tex"; + } }; } #endif \ No newline at end of file