Add Classification report to end of experiment if only one dataset is tested

This commit is contained in:
2024-05-10 14:11:51 +02:00
parent ec0268c514
commit 6d4117d188
6 changed files with 24 additions and 4 deletions

View File

@@ -64,5 +64,6 @@ add_executable(
common/Datasets.cpp common/Dataset.cpp common/Datasets.cpp common/Dataset.cpp
reports/ReportConsole.cpp reports/ReportExcel.cpp reports/ReportExcelCompared.cpp reports/ReportBase.cpp reports/ExcelFile.cpp reports/DatasetsConsole.cpp reports/ResultsDatasetConsole.cpp reports/ReportsPaged.cpp reports/ReportConsole.cpp reports/ReportExcel.cpp reports/ReportExcelCompared.cpp reports/ReportBase.cpp reports/ExcelFile.cpp reports/DatasetsConsole.cpp reports/ResultsDatasetConsole.cpp reports/ReportsPaged.cpp
results/Result.cpp results/ResultsDataset.cpp results/Result.cpp results/ResultsDataset.cpp
main/Scores.cpp
) )
target_link_libraries(b_manage "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" ArffFiles mdlp "${BayesNet}") target_link_libraries(b_manage "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" ArffFiles mdlp "${BayesNet}")

View File

@@ -189,8 +189,10 @@ int main(int argc, char** argv)
if (saveResults) { if (saveResults) {
experiment.saveResult(); experiment.saveResult();
} }
if (!quiet) if (!quiet) {
experiment.report(); // Classification report if only one dataset is tested
experiment.report(filesToTest.size() == 1);
}
std::cout << "Done!" << std::endl; std::cout << "Done!" << std::endl;
return 0; return 0;
} }

View File

@@ -11,10 +11,13 @@ namespace platform {
{ {
result.save(); result.save();
} }
void Experiment::report() void Experiment::report(bool classification_report)
{ {
ReportConsole report(result.getJson()); ReportConsole report(result.getJson());
report.show(); report.show();
if (classification_report) {
report.showClassificationReport();
}
} }
void Experiment::show() void Experiment::show()
{ {

View File

@@ -32,7 +32,7 @@ namespace platform {
void go(std::vector<std::string> filesToProcess, bool quiet, bool no_train_score); void go(std::vector<std::string> filesToProcess, bool quiet, bool no_train_score);
void saveResult(); void saveResult();
void show(); void show();
void report(); void report(bool classification_report = false);
private: private:
Result result; Result result;
bool discretized{ false }, stratified{ false }; bool discretized{ false }, stratified{ false };

View File

@@ -3,6 +3,7 @@
#include "best/BestScore.h" #include "best/BestScore.h"
#include "common/CLocale.h" #include "common/CLocale.h"
#include "ReportConsole.h" #include "ReportConsole.h"
#include "main/Scores.h"
namespace platform { namespace platform {
std::string ReportConsole::headerLine(const std::string& text, int utf = 0) std::string ReportConsole::headerLine(const std::string& text, int utf = 0)
@@ -164,4 +165,16 @@ namespace platform {
std::cout << headerLine("*** Best Results File not found. Couldn't compare any result!"); std::cout << headerLine("*** Best Results File not found. Couldn't compare any result!");
} }
} }
void ReportConsole::showClassificationReport()
{
if (data["results"].size() > 1)
return;
auto item = data["results"][0];
auto scores = Scores(item["confusion_matrices"][0]);
for (int i = 1; i < item["confusion_matrices"].size(); i++) {
auto score = Scores(item["confusion_matrices"][i]);
scores.aggregate(score);
}
std::cout << Colors::BLUE() << scores.classification_report() << Colors::RESET();
}
} }

View File

@@ -14,6 +14,7 @@ namespace platform {
std::string fileReport(); std::string fileReport();
std::string getHeader() { do_header(); do_body(); return sheader.str(); } std::string getHeader() { do_header(); do_body(); return sheader.str(); }
std::vector<std::string>& getBody() { return vbody; } std::vector<std::string>& getBody() { return vbody; }
void showClassificationReport();
private: private:
int selectedIndex; int selectedIndex;
std::string headerLine(const std::string& text, int utf); std::string headerLine(const std::string& text, int utf);