Add Graphs to results
Add bin5..bin10 q & u discretizers algos Fix trouble in computing states Update mdlp to 2.0.0
This commit is contained in:
@@ -24,7 +24,24 @@ namespace platform {
|
||||
{
|
||||
std::cout << result.getJson().dump(4) << std::endl;
|
||||
}
|
||||
void Experiment::go(std::vector<std::string> filesToProcess, bool quiet, bool no_train_score, bool generate_fold_files)
|
||||
void Experiment::saveGraph()
|
||||
{
|
||||
std::cout << "Saving graphs..." << std::endl;
|
||||
auto data = result.getJson();
|
||||
for (const auto& item : data["results"]) {
|
||||
auto graphs = item["graph"];
|
||||
int i = 0;
|
||||
for (const auto& graph : graphs) {
|
||||
i++;
|
||||
auto fileName = Paths::graphs() + result.getFilename() + "_graph_" + item["dataset"].get<std::string>() + "_" + std::to_string(i) + ".dot";
|
||||
auto file = std::ofstream(fileName);
|
||||
file << graph.get<std::string>();
|
||||
file.close();
|
||||
std::cout << "Graph saved in " << fileName << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
void Experiment::go(std::vector<std::string> filesToProcess, bool quiet, bool no_train_score, bool generate_fold_files, bool graph)
|
||||
{
|
||||
for (auto fileName : filesToProcess) {
|
||||
if (fileName.size() > max_name)
|
||||
@@ -48,7 +65,7 @@ namespace platform {
|
||||
for (auto fileName : filesToProcess) {
|
||||
if (!quiet)
|
||||
std::cout << " " << setw(3) << right << num++ << " " << setw(max_name) << left << fileName << right << flush;
|
||||
cross_validation(fileName, quiet, no_train_score, generate_fold_files);
|
||||
cross_validation(fileName, quiet, no_train_score, generate_fold_files, graph);
|
||||
if (!quiet)
|
||||
std::cout << std::endl;
|
||||
}
|
||||
@@ -71,7 +88,7 @@ namespace platform {
|
||||
|
||||
void showProgress(int fold, const std::string& color, const std::string& phase)
|
||||
{
|
||||
std::string prefix = phase == "a" ? "" : "\b\b\b\b";
|
||||
std::string prefix = phase == "-" ? "" : "\b\b\b\b";
|
||||
std::cout << prefix << color << fold << Colors::RESET() << "(" << color << phase << Colors::RESET() << ")" << flush;
|
||||
|
||||
}
|
||||
@@ -113,7 +130,7 @@ namespace platform {
|
||||
file << output.dump(4);
|
||||
file.close();
|
||||
}
|
||||
void Experiment::cross_validation(const std::string& fileName, bool quiet, bool no_train_score, bool generate_fold_files)
|
||||
void Experiment::cross_validation(const std::string& fileName, bool quiet, bool no_train_score, bool generate_fold_files, bool graph)
|
||||
{
|
||||
//
|
||||
// Load dataset and prepare data
|
||||
@@ -151,6 +168,7 @@ namespace platform {
|
||||
json confusion_matrices = json::array();
|
||||
json confusion_matrices_train = json::array();
|
||||
std::vector<std::string> notes;
|
||||
std::vector<std::string> graphs;
|
||||
Timer train_timer, test_timer;
|
||||
int item = 0;
|
||||
bool first_seed = true;
|
||||
@@ -176,6 +194,8 @@ namespace platform {
|
||||
//
|
||||
for (int nfold = 0; nfold < nfolds; nfold++) {
|
||||
auto clf = Models::instance()->create(result.getModel());
|
||||
if (!quiet)
|
||||
showProgress(nfold + 1, getColor(clf->getStatus()), "-");
|
||||
setModelVersion(clf->getVersion());
|
||||
auto valid = clf->getValidHyperparameters();
|
||||
hyperparameters.check(valid, fileName);
|
||||
@@ -237,6 +257,13 @@ namespace platform {
|
||||
partial_result.addTimeTrain(train_time[item].item<double>());
|
||||
partial_result.addTimeTest(test_time[item].item<double>());
|
||||
item++;
|
||||
if (graph) {
|
||||
std::string result = "";
|
||||
for (const auto& line : clf->graph()) {
|
||||
result += line + "\n";
|
||||
}
|
||||
graphs.push_back(result);
|
||||
}
|
||||
}
|
||||
if (!quiet)
|
||||
std::cout << "end. " << flush;
|
||||
@@ -245,6 +272,7 @@ namespace platform {
|
||||
//
|
||||
// Store result totals in Result
|
||||
//
|
||||
partial_result.setGraph(graphs);
|
||||
partial_result.setScoreTest(torch::mean(accuracy_test).item<double>()).setScoreTrain(torch::mean(accuracy_train).item<double>());
|
||||
partial_result.setScoreTestStd(torch::std(accuracy_test).item<double>()).setScoreTrainStd(torch::std(accuracy_train).item<double>());
|
||||
partial_result.setTrainTime(torch::mean(train_time).item<double>()).setTestTime(torch::mean(test_time).item<double>());
|
||||
|
Reference in New Issue
Block a user