Add Graphs to results

Add bin5..bin10 q & u discretizers algos
Fix trouble in computing states
Update mdlp to 2.0.0
This commit is contained in:
2024-07-11 11:23:20 +02:00
parent 3acc34e4c6
commit 26dfe6d056
15 changed files with 83 additions and 17 deletions

View File

@@ -24,7 +24,24 @@ namespace platform {
{
std::cout << result.getJson().dump(4) << std::endl;
}
void Experiment::go(std::vector<std::string> filesToProcess, bool quiet, bool no_train_score, bool generate_fold_files)
void Experiment::saveGraph()
{
std::cout << "Saving graphs..." << std::endl;
auto data = result.getJson();
for (const auto& item : data["results"]) {
auto graphs = item["graph"];
int i = 0;
for (const auto& graph : graphs) {
i++;
auto fileName = Paths::graphs() + result.getFilename() + "_graph_" + item["dataset"].get<std::string>() + "_" + std::to_string(i) + ".dot";
auto file = std::ofstream(fileName);
file << graph.get<std::string>();
file.close();
std::cout << "Graph saved in " << fileName << std::endl;
}
}
}
void Experiment::go(std::vector<std::string> filesToProcess, bool quiet, bool no_train_score, bool generate_fold_files, bool graph)
{
for (auto fileName : filesToProcess) {
if (fileName.size() > max_name)
@@ -48,7 +65,7 @@ namespace platform {
for (auto fileName : filesToProcess) {
if (!quiet)
std::cout << " " << setw(3) << right << num++ << " " << setw(max_name) << left << fileName << right << flush;
cross_validation(fileName, quiet, no_train_score, generate_fold_files);
cross_validation(fileName, quiet, no_train_score, generate_fold_files, graph);
if (!quiet)
std::cout << std::endl;
}
@@ -71,7 +88,7 @@ namespace platform {
void showProgress(int fold, const std::string& color, const std::string& phase)
{
std::string prefix = phase == "a" ? "" : "\b\b\b\b";
std::string prefix = phase == "-" ? "" : "\b\b\b\b";
std::cout << prefix << color << fold << Colors::RESET() << "(" << color << phase << Colors::RESET() << ")" << flush;
}
@@ -113,7 +130,7 @@ namespace platform {
file << output.dump(4);
file.close();
}
void Experiment::cross_validation(const std::string& fileName, bool quiet, bool no_train_score, bool generate_fold_files)
void Experiment::cross_validation(const std::string& fileName, bool quiet, bool no_train_score, bool generate_fold_files, bool graph)
{
//
// Load dataset and prepare data
@@ -151,6 +168,7 @@ namespace platform {
json confusion_matrices = json::array();
json confusion_matrices_train = json::array();
std::vector<std::string> notes;
std::vector<std::string> graphs;
Timer train_timer, test_timer;
int item = 0;
bool first_seed = true;
@@ -176,6 +194,8 @@ namespace platform {
//
for (int nfold = 0; nfold < nfolds; nfold++) {
auto clf = Models::instance()->create(result.getModel());
if (!quiet)
showProgress(nfold + 1, getColor(clf->getStatus()), "-");
setModelVersion(clf->getVersion());
auto valid = clf->getValidHyperparameters();
hyperparameters.check(valid, fileName);
@@ -237,6 +257,13 @@ namespace platform {
partial_result.addTimeTrain(train_time[item].item<double>());
partial_result.addTimeTest(test_time[item].item<double>());
item++;
if (graph) {
std::string result = "";
for (const auto& line : clf->graph()) {
result += line + "\n";
}
graphs.push_back(result);
}
}
if (!quiet)
std::cout << "end. " << flush;
@@ -245,6 +272,7 @@ namespace platform {
//
// Store result totals in Result
//
partial_result.setGraph(graphs);
partial_result.setScoreTest(torch::mean(accuracy_test).item<double>()).setScoreTrain(torch::mean(accuracy_train).item<double>());
partial_result.setScoreTestStd(torch::std(accuracy_test).item<double>()).setScoreTrainStd(torch::std(accuracy_train).item<double>());
partial_result.setTrainTime(torch::mean(train_time).item<double>()).setTestTime(torch::mean(test_time).item<double>());