Merge pull request 'reports' (#4) from reports into boostAode
Reviewed-on: https://gitea.rmontanana.es:11000/rmontanana/BayesNet/pulls/4
This commit is contained in:
commit
a062ebf445
11
.vscode/launch.json
vendored
11
.vscode/launch.json
vendored
@ -35,6 +35,17 @@
|
|||||||
],
|
],
|
||||||
"cwd": "/Users/rmontanana/Code/discretizbench",
|
"cwd": "/Users/rmontanana/Code/discretizbench",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"type": "lldb",
|
||||||
|
"request": "launch",
|
||||||
|
"name": "manage",
|
||||||
|
"program": "${workspaceFolder}/build/src/Platform/manage",
|
||||||
|
"args": [
|
||||||
|
"-n",
|
||||||
|
"20"
|
||||||
|
],
|
||||||
|
"cwd": "/Users/rmontanana/Code/discretizbench",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "Build & debug active file",
|
"name": "Build & debug active file",
|
||||||
"type": "cppdbg",
|
"type": "cppdbg",
|
||||||
|
10
src/Platform/BestResult.h
Normal file
10
src/Platform/BestResult.h
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
#ifndef BESTRESULT_H
|
||||||
|
#define BESTRESULT_H
|
||||||
|
#include <string>
|
||||||
|
class BestResult {
|
||||||
|
public:
|
||||||
|
static std::string title() { return "STree_default (linear-ovo)"; }
|
||||||
|
static double score() { return 22.109799; }
|
||||||
|
static std::string scoreName() { return "accuracy"; }
|
||||||
|
};
|
||||||
|
#endif
|
@ -5,4 +5,6 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
|
|||||||
include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include)
|
include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include)
|
||||||
include_directories(${BayesNet_SOURCE_DIR}/lib/json/include)
|
include_directories(${BayesNet_SOURCE_DIR}/lib/json/include)
|
||||||
add_executable(main main.cc Folding.cc platformUtils.cc Experiment.cc Datasets.cc Models.cc Report.cc)
|
add_executable(main main.cc Folding.cc platformUtils.cc Experiment.cc Datasets.cc Models.cc Report.cc)
|
||||||
target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}")
|
add_executable(manage manage.cc Results.cc Report.cc)
|
||||||
|
target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}")
|
||||||
|
target_link_libraries(manage "${TORCH_LIBRARIES}")
|
14
src/Platform/Colors.h
Normal file
14
src/Platform/Colors.h
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
#ifndef COLORS_H
|
||||||
|
#define COLORS_H
|
||||||
|
class Colors {
|
||||||
|
public:
|
||||||
|
static std::string MAGENTA() { return "\033[1;35m"; }
|
||||||
|
static std::string BLUE() { return "\033[1;34m"; }
|
||||||
|
static std::string CYAN() { return "\033[1;36m"; }
|
||||||
|
static std::string GREEN() { return "\033[1;32m"; }
|
||||||
|
static std::string YELLOW() { return "\033[1;33m"; }
|
||||||
|
static std::string RED() { return "\033[1;31m"; }
|
||||||
|
static std::string WHITE() { return "\033[1;37m"; }
|
||||||
|
static std::string RESET() { return "\033[0m"; }
|
||||||
|
};
|
||||||
|
#endif // COLORS_H
|
10
src/Platform/Paths.h
Normal file
10
src/Platform/Paths.h
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
#ifndef PATHS_H
|
||||||
|
#define PATHS_H
|
||||||
|
namespace platform {
|
||||||
|
class Paths {
|
||||||
|
public:
|
||||||
|
static std::string datasets() { return "datasets/"; }
|
||||||
|
static std::string results() { return "results/"; }
|
||||||
|
};
|
||||||
|
}
|
||||||
|
#endif
|
@ -1,4 +1,5 @@
|
|||||||
#include "Report.h"
|
#include "Report.h"
|
||||||
|
#include "BestResult.h"
|
||||||
|
|
||||||
namespace platform {
|
namespace platform {
|
||||||
string headerLine(const string& text)
|
string headerLine(const string& text)
|
||||||
@ -28,10 +29,11 @@ namespace platform {
|
|||||||
{
|
{
|
||||||
header();
|
header();
|
||||||
body();
|
body();
|
||||||
|
footer();
|
||||||
}
|
}
|
||||||
void Report::header()
|
void Report::header()
|
||||||
{
|
{
|
||||||
cout << string(MAXL, '*') << endl;
|
cout << Colors::MAGENTA() << string(MAXL, '*') << endl;
|
||||||
cout << headerLine("Report " + data["model"].get<string>() + " ver. " + data["version"].get<string>() + " with " + to_string(data["folds"].get<int>()) + " Folds cross validation and " + to_string(data["seeds"].size()) + " random seeds. " + data["date"].get<string>() + " " + data["time"].get<string>());
|
cout << headerLine("Report " + data["model"].get<string>() + " ver. " + data["version"].get<string>() + " with " + to_string(data["folds"].get<int>()) + " Folds cross validation and " + to_string(data["seeds"].size()) + " random seeds. " + data["date"].get<string>() + " " + data["time"].get<string>());
|
||||||
cout << headerLine(data["title"].get<string>());
|
cout << headerLine(data["title"].get<string>());
|
||||||
cout << headerLine("Random seeds: " + fromVector("seeds") + " Stratified: " + (data["stratified"].get<bool>() ? "True" : "False"));
|
cout << headerLine("Random seeds: " + fromVector("seeds") + " Stratified: " + (data["stratified"].get<bool>() ? "True" : "False"));
|
||||||
@ -42,26 +44,50 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
void Report::body()
|
void Report::body()
|
||||||
{
|
{
|
||||||
cout << "Dataset Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl;
|
cout << Colors::GREEN() << "Dataset Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl;
|
||||||
cout << "============================== ====== ===== === ======= ======= ======= =============== ================= ===============" << endl;
|
cout << "============================== ====== ===== === ======= ======= ======= =============== ================== ===============" << endl;
|
||||||
|
json lastResult;
|
||||||
|
totalScore = 0;
|
||||||
|
bool odd = true;
|
||||||
for (const auto& r : data["results"]) {
|
for (const auto& r : data["results"]) {
|
||||||
cout << setw(30) << left << r["dataset"].get<string>() << " ";
|
auto color = odd ? Colors::CYAN() : Colors::BLUE();
|
||||||
|
cout << color << setw(30) << left << r["dataset"].get<string>() << " ";
|
||||||
cout << setw(6) << right << r["samples"].get<int>() << " ";
|
cout << setw(6) << right << r["samples"].get<int>() << " ";
|
||||||
cout << setw(5) << right << r["features"].get<int>() << " ";
|
cout << setw(5) << right << r["features"].get<int>() << " ";
|
||||||
cout << setw(3) << right << r["classes"].get<int>() << " ";
|
cout << setw(3) << right << r["classes"].get<int>() << " ";
|
||||||
cout << setw(7) << setprecision(2) << fixed << r["nodes"].get<float>() << " ";
|
cout << setw(7) << setprecision(2) << fixed << r["nodes"].get<float>() << " ";
|
||||||
cout << setw(7) << setprecision(2) << fixed << r["leaves"].get<float>() << " ";
|
cout << setw(7) << setprecision(2) << fixed << r["leaves"].get<float>() << " ";
|
||||||
cout << setw(7) << setprecision(2) << fixed << r["depth"].get<float>() << " ";
|
cout << setw(7) << setprecision(2) << fixed << r["depth"].get<float>() << " ";
|
||||||
cout << setw(8) << right << setprecision(6) << fixed << r["score_test"].get<double>() << "±" << setw(6) << setprecision(4) << fixed << r["score_test_std"].get<double>() << " ";
|
cout << setw(8) << right << setprecision(6) << fixed << r["score"].get<double>() << "±" << setw(6) << setprecision(4) << fixed << r["score_std"].get<double>() << " ";
|
||||||
cout << setw(10) << right << setprecision(6) << fixed << r["test_time"].get<double>() << "±" << setw(6) << setprecision(4) << fixed << r["test_time_std"].get<double>() << " ";
|
cout << setw(11) << right << setprecision(6) << fixed << r["time"].get<double>() << "±" << setw(6) << setprecision(4) << fixed << r["time_std"].get<double>() << " ";
|
||||||
cout << " " << r["hyperparameters"].get<string>();
|
try {
|
||||||
|
cout << r["hyperparameters"].get<string>();
|
||||||
|
}
|
||||||
|
catch (const exception& err) {
|
||||||
|
cout << r["hyperparameters"];
|
||||||
|
}
|
||||||
cout << endl;
|
cout << endl;
|
||||||
|
lastResult = r;
|
||||||
|
totalScore += r["score"].get<double>();
|
||||||
|
odd = !odd;
|
||||||
|
}
|
||||||
|
if (data["results"].size() == 1) {
|
||||||
cout << string(MAXL, '*') << endl;
|
cout << string(MAXL, '*') << endl;
|
||||||
cout << headerLine("Train scores: " + fVector(r["scores_train"]));
|
cout << headerLine("Train scores: " + fVector(lastResult["scores_train"]));
|
||||||
cout << headerLine("Test scores: " + fVector(r["scores_test"]));
|
cout << headerLine("Test scores: " + fVector(lastResult["scores_test"]));
|
||||||
cout << headerLine("Train times: " + fVector(r["times_train"]));
|
cout << headerLine("Train times: " + fVector(lastResult["times_train"]));
|
||||||
cout << headerLine("Test times: " + fVector(r["times_test"]));
|
cout << headerLine("Test times: " + fVector(lastResult["times_test"]));
|
||||||
cout << string(MAXL, '*') << endl;
|
cout << string(MAXL, '*') << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
void Report::footer()
|
||||||
|
{
|
||||||
|
cout << Colors::MAGENTA() << string(MAXL, '*') << endl;
|
||||||
|
auto score = data["score_name"].get<string>();
|
||||||
|
if (score == BestResult::scoreName()) {
|
||||||
|
cout << headerLine(score + " compared to " + BestResult::title() + " .: " + to_string(totalScore / BestResult::score()));
|
||||||
|
}
|
||||||
|
cout << string(MAXL, '*') << endl << Colors::RESET();
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
@ -3,6 +3,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <nlohmann/json.hpp>
|
#include <nlohmann/json.hpp>
|
||||||
|
#include "Colors.h"
|
||||||
|
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
const int MAXL = 121;
|
const int MAXL = 121;
|
||||||
@ -16,8 +17,10 @@ namespace platform {
|
|||||||
private:
|
private:
|
||||||
void header();
|
void header();
|
||||||
void body();
|
void body();
|
||||||
|
void footer();
|
||||||
string fromVector(const string& key);
|
string fromVector(const string& key);
|
||||||
json data;
|
json data;
|
||||||
|
double totalScore; // Total score of all results in a report
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
#endif
|
#endif
|
239
src/Platform/Results.cc
Normal file
239
src/Platform/Results.cc
Normal file
@ -0,0 +1,239 @@
|
|||||||
|
#include <filesystem>
|
||||||
|
#include "platformUtils.h"
|
||||||
|
#include "Results.h"
|
||||||
|
#include "Report.h"
|
||||||
|
#include "BestResult.h"
|
||||||
|
#include "Colors.h"
|
||||||
|
namespace platform {
|
||||||
|
Result::Result(const string& path, const string& filename)
|
||||||
|
: path(path)
|
||||||
|
, filename(filename)
|
||||||
|
{
|
||||||
|
auto data = load();
|
||||||
|
date = data["date"];
|
||||||
|
score = 0;
|
||||||
|
for (const auto& result : data["results"]) {
|
||||||
|
score += result["score"].get<double>();
|
||||||
|
}
|
||||||
|
scoreName = data["score_name"];
|
||||||
|
if (scoreName == BestResult::scoreName()) {
|
||||||
|
score /= BestResult::score();
|
||||||
|
}
|
||||||
|
title = data["title"];
|
||||||
|
duration = data["duration"];
|
||||||
|
model = data["model"];
|
||||||
|
}
|
||||||
|
json Result::load() const
|
||||||
|
{
|
||||||
|
ifstream resultData(path + "/" + filename);
|
||||||
|
if (resultData.is_open()) {
|
||||||
|
json data = json::parse(resultData);
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
throw invalid_argument("Unable to open result file. [" + path + "/" + filename + "]");
|
||||||
|
}
|
||||||
|
void Results::load()
|
||||||
|
{
|
||||||
|
using std::filesystem::directory_iterator;
|
||||||
|
for (const auto& file : directory_iterator(path)) {
|
||||||
|
auto filename = file.path().filename().string();
|
||||||
|
if (filename.find(".json") != string::npos && filename.find("results_") == 0) {
|
||||||
|
auto result = Result(path, filename);
|
||||||
|
bool addResult = true;
|
||||||
|
if (model != "any" && result.getModel() != model || scoreName != "any" && scoreName != result.getScoreName())
|
||||||
|
addResult = false;
|
||||||
|
if (addResult)
|
||||||
|
files.push_back(result);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
string Result::to_string() const
|
||||||
|
{
|
||||||
|
stringstream oss;
|
||||||
|
oss << date << " ";
|
||||||
|
oss << setw(12) << left << model << " ";
|
||||||
|
oss << setw(11) << left << scoreName << " ";
|
||||||
|
oss << right << setw(11) << setprecision(7) << fixed << score << " ";
|
||||||
|
oss << setw(9) << setprecision(3) << fixed << duration << " ";
|
||||||
|
oss << setw(50) << left << title << " ";
|
||||||
|
return oss.str();
|
||||||
|
}
|
||||||
|
void Results::show() const
|
||||||
|
{
|
||||||
|
cout << Colors::GREEN() << "Results found: " << files.size() << endl;
|
||||||
|
cout << "-------------------" << endl;
|
||||||
|
auto i = 0;
|
||||||
|
cout << " # Date Model Score Name Score Duration Title" << endl;
|
||||||
|
cout << "=== ========== ============ =========== =========== ========= =============================================================" << endl;
|
||||||
|
bool odd = true;
|
||||||
|
for (const auto& result : files) {
|
||||||
|
auto color = odd ? Colors::BLUE() : Colors::CYAN();
|
||||||
|
cout << color << setw(3) << fixed << right << i++ << " ";
|
||||||
|
cout << result.to_string() << endl;
|
||||||
|
if (i == max && max != 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
odd = !odd;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int Results::getIndex(const string& intent) const
|
||||||
|
{
|
||||||
|
string color;
|
||||||
|
if (intent == "delete") {
|
||||||
|
color = Colors::RED();
|
||||||
|
} else {
|
||||||
|
color = Colors::YELLOW();
|
||||||
|
}
|
||||||
|
cout << color << "Choose result to " << intent << " (cancel=-1): ";
|
||||||
|
string line;
|
||||||
|
getline(cin, line);
|
||||||
|
int index = stoi(line);
|
||||||
|
if (index >= -1 && index < static_cast<int>(files.size())) {
|
||||||
|
return index;
|
||||||
|
}
|
||||||
|
cout << "Invalid index" << endl;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
void Results::report(const int index) const
|
||||||
|
{
|
||||||
|
cout << Colors::YELLOW() << "Reporting " << files.at(index).getFilename() << endl;
|
||||||
|
auto data = files.at(index).load();
|
||||||
|
Report report(data);
|
||||||
|
report.show();
|
||||||
|
}
|
||||||
|
void Results::menu()
|
||||||
|
{
|
||||||
|
char option;
|
||||||
|
int index;
|
||||||
|
bool finished = false;
|
||||||
|
string filename, line, options = "qldhsr";
|
||||||
|
while (!finished) {
|
||||||
|
cout << Colors::RESET() << "Choose option (quit='q', list='l', delete='d', hide='h', sort='s', report='r'): ";
|
||||||
|
getline(cin, line);
|
||||||
|
if (line.size() == 0)
|
||||||
|
continue;
|
||||||
|
if (options.find(line[0]) != string::npos) {
|
||||||
|
if (line.size() > 1) {
|
||||||
|
cout << "Invalid option" << endl;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
option = line[0];
|
||||||
|
} else {
|
||||||
|
index = stoi(line);
|
||||||
|
if (index >= 0 && index < files.size()) {
|
||||||
|
report(index);
|
||||||
|
} else {
|
||||||
|
cout << "Invalid option" << endl;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
switch (option) {
|
||||||
|
case 'q':
|
||||||
|
finished = true;
|
||||||
|
break;
|
||||||
|
case 'l':
|
||||||
|
show();
|
||||||
|
break;
|
||||||
|
case 'd':
|
||||||
|
index = getIndex("delete");
|
||||||
|
if (index == -1)
|
||||||
|
break;
|
||||||
|
filename = files[index].getFilename();
|
||||||
|
cout << "Deleting " << filename << endl;
|
||||||
|
remove((path + "/" + filename).c_str());
|
||||||
|
files.erase(files.begin() + index);
|
||||||
|
cout << "File: " + filename + " deleted!" << endl;
|
||||||
|
show();
|
||||||
|
break;
|
||||||
|
case 'h':
|
||||||
|
index = getIndex("hide");
|
||||||
|
if (index == -1)
|
||||||
|
break;
|
||||||
|
filename = files[index].getFilename();
|
||||||
|
cout << "Hiding " << filename << endl;
|
||||||
|
rename((path + "/" + filename).c_str(), (path + "/." + filename).c_str());
|
||||||
|
files.erase(files.begin() + index);
|
||||||
|
show();
|
||||||
|
menu();
|
||||||
|
break;
|
||||||
|
case 's':
|
||||||
|
sortList();
|
||||||
|
show();
|
||||||
|
break;
|
||||||
|
case 'r':
|
||||||
|
index = getIndex("report");
|
||||||
|
if (index == -1)
|
||||||
|
break;
|
||||||
|
report(index);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
cout << "Invalid option" << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void Results::sortList()
|
||||||
|
{
|
||||||
|
cout << Colors::YELLOW() << "Choose sorting field (date='d', score='s', duration='u', model='m'): ";
|
||||||
|
string line;
|
||||||
|
char option;
|
||||||
|
getline(cin, line);
|
||||||
|
if (line.size() == 0)
|
||||||
|
return;
|
||||||
|
if (line.size() > 1) {
|
||||||
|
cout << "Invalid option" << endl;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
option = line[0];
|
||||||
|
switch (option) {
|
||||||
|
case 'd':
|
||||||
|
sortDate();
|
||||||
|
break;
|
||||||
|
case 's':
|
||||||
|
sortScore();
|
||||||
|
break;
|
||||||
|
case 'u':
|
||||||
|
sortDuration();
|
||||||
|
break;
|
||||||
|
case 'm':
|
||||||
|
sortModel();
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
cout << "Invalid option" << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void Results::sortDate()
|
||||||
|
{
|
||||||
|
sort(files.begin(), files.end(), [](const Result& a, const Result& b) {
|
||||||
|
return a.getDate() > b.getDate();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
void Results::sortModel()
|
||||||
|
{
|
||||||
|
sort(files.begin(), files.end(), [](const Result& a, const Result& b) {
|
||||||
|
return a.getModel() > b.getModel();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
void Results::sortDuration()
|
||||||
|
{
|
||||||
|
sort(files.begin(), files.end(), [](const Result& a, const Result& b) {
|
||||||
|
return a.getDuration() > b.getDuration();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
void Results::sortScore()
|
||||||
|
{
|
||||||
|
sort(files.begin(), files.end(), [](const Result& a, const Result& b) {
|
||||||
|
return a.getScore() > b.getScore();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
void Results::manage()
|
||||||
|
{
|
||||||
|
if (files.size() == 0) {
|
||||||
|
cout << "No results found!" << endl;
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
show();
|
||||||
|
menu();
|
||||||
|
cout << "Done!" << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
56
src/Platform/Results.h
Normal file
56
src/Platform/Results.h
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
#ifndef RESULTS_H
|
||||||
|
#define RESULTS_H
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
namespace platform {
|
||||||
|
using namespace std;
|
||||||
|
using json = nlohmann::json;
|
||||||
|
|
||||||
|
class Result {
|
||||||
|
public:
|
||||||
|
Result(const string& path, const string& filename);
|
||||||
|
json load() const;
|
||||||
|
string to_string() const;
|
||||||
|
string getFilename() const { return filename; };
|
||||||
|
string getDate() const { return date; };
|
||||||
|
double getScore() const { return score; };
|
||||||
|
string getTitle() const { return title; };
|
||||||
|
double getDuration() const { return duration; };
|
||||||
|
string getModel() const { return model; };
|
||||||
|
string getScoreName() const { return scoreName; };
|
||||||
|
private:
|
||||||
|
string path;
|
||||||
|
string filename;
|
||||||
|
string date;
|
||||||
|
double score;
|
||||||
|
string title;
|
||||||
|
double duration;
|
||||||
|
string model;
|
||||||
|
string scoreName;
|
||||||
|
};
|
||||||
|
class Results {
|
||||||
|
public:
|
||||||
|
Results(const string& path, const int max, const string& model, const string& score) : path(path), max(max), model(model), scoreName(score) { load(); };
|
||||||
|
void manage();
|
||||||
|
private:
|
||||||
|
string path;
|
||||||
|
int max;
|
||||||
|
string model;
|
||||||
|
string scoreName;
|
||||||
|
vector<Result> files;
|
||||||
|
void load(); // Loads the list of results
|
||||||
|
void show() const;
|
||||||
|
void report(const int index) const;
|
||||||
|
int getIndex(const string& intent) const;
|
||||||
|
void menu();
|
||||||
|
void sortList();
|
||||||
|
void sortDate();
|
||||||
|
void sortScore();
|
||||||
|
void sortModel();
|
||||||
|
void sortDuration();
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif
|
@ -6,20 +6,19 @@
|
|||||||
#include "DotEnv.h"
|
#include "DotEnv.h"
|
||||||
#include "Models.h"
|
#include "Models.h"
|
||||||
#include "modelRegister.h"
|
#include "modelRegister.h"
|
||||||
|
#include "Paths.h"
|
||||||
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
const string PATH_RESULTS = "results";
|
|
||||||
const string PATH_DATASETS = "datasets";
|
|
||||||
|
|
||||||
argparse::ArgumentParser manageArguments(int argc, char** argv)
|
argparse::ArgumentParser manageArguments(int argc, char** argv)
|
||||||
{
|
{
|
||||||
auto env = platform::DotEnv();
|
auto env = platform::DotEnv();
|
||||||
argparse::ArgumentParser program("BayesNetSample");
|
argparse::ArgumentParser program("main");
|
||||||
program.add_argument("-d", "--dataset").default_value("").help("Dataset file name");
|
program.add_argument("-d", "--dataset").default_value("").help("Dataset file name");
|
||||||
program.add_argument("-p", "--path")
|
program.add_argument("-p", "--path")
|
||||||
.help("folder where the data files are located, default")
|
.help("folder where the data files are located, default")
|
||||||
.default_value(string{ PATH_DATASETS }
|
.default_value(string{ platform::Paths::datasets() });
|
||||||
);
|
|
||||||
program.add_argument("-m", "--model")
|
program.add_argument("-m", "--model")
|
||||||
.help("Model to use " + platform::Models::instance()->toString())
|
.help("Model to use " + platform::Models::instance()->toString())
|
||||||
.action([](const std::string& value) {
|
.action([](const std::string& value) {
|
||||||
@ -115,7 +114,7 @@ int main(int argc, char** argv)
|
|||||||
experiment.go(filesToTest, path);
|
experiment.go(filesToTest, path);
|
||||||
experiment.setDuration(timer.getDuration());
|
experiment.setDuration(timer.getDuration());
|
||||||
if (saveResults)
|
if (saveResults)
|
||||||
experiment.save(PATH_RESULTS);
|
experiment.save(platform::Paths::results());
|
||||||
else
|
else
|
||||||
experiment.report();
|
experiment.report();
|
||||||
cout << "Done!" << endl;
|
cout << "Done!" << endl;
|
||||||
|
41
src/Platform/manage.cc
Normal file
41
src/Platform/manage.cc
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
#include <iostream>
|
||||||
|
#include <argparse/argparse.hpp>
|
||||||
|
#include "platformUtils.h"
|
||||||
|
#include "Paths.h"
|
||||||
|
#include "Results.h"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
argparse::ArgumentParser manageArguments(int argc, char** argv)
|
||||||
|
{
|
||||||
|
argparse::ArgumentParser program("manage");
|
||||||
|
program.add_argument("-n", "--number").default_value(0).help("Number of results to show (0 = all)").scan<'i', int>();
|
||||||
|
program.add_argument("-m", "--model").default_value("any").help("Filter results of the selected model)");
|
||||||
|
program.add_argument("-s", "--score").default_value("any").help("Filter results of the score name supplied");
|
||||||
|
try {
|
||||||
|
program.parse_args(argc, argv);
|
||||||
|
auto number = program.get<int>("number");
|
||||||
|
if (number < 0) {
|
||||||
|
throw runtime_error("Number of results must be greater than or equal to 0");
|
||||||
|
}
|
||||||
|
auto model = program.get<string>("model");
|
||||||
|
auto score = program.get<string>("score");
|
||||||
|
}
|
||||||
|
catch (const exception& err) {
|
||||||
|
cerr << err.what() << endl;
|
||||||
|
cerr << program;
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
return program;
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char** argv)
|
||||||
|
{
|
||||||
|
auto program = manageArguments(argc, argv);
|
||||||
|
auto number = program.get<int>("number");
|
||||||
|
auto model = program.get<string>("model");
|
||||||
|
auto score = program.get<string>("score");
|
||||||
|
auto results = platform::Results(platform::Paths::results(), number, model, score);
|
||||||
|
results.manage();
|
||||||
|
return 0;
|
||||||
|
}
|
@ -1,4 +1,5 @@
|
|||||||
#include "platformUtils.h"
|
#include "platformUtils.h"
|
||||||
|
#include "Paths.h"
|
||||||
|
|
||||||
using namespace torch;
|
using namespace torch;
|
||||||
|
|
||||||
@ -85,7 +86,7 @@ tuple<Tensor, Tensor, vector<string>, string, map<string, vector<int>>> loadData
|
|||||||
tuple<vector<vector<int>>, vector<int>, vector<string>, string, map<string, vector<int>>> loadFile(const string& name)
|
tuple<vector<vector<int>>, vector<int>, vector<string>, string, map<string, vector<int>>> loadFile(const string& name)
|
||||||
{
|
{
|
||||||
auto handler = ArffFiles();
|
auto handler = ArffFiles();
|
||||||
handler.load(PATH + static_cast<string>(name) + ".arff");
|
handler.load(platform::Paths::datasets() + static_cast<string>(name) + ".arff");
|
||||||
// Get Dataset X, y
|
// Get Dataset X, y
|
||||||
vector<mdlp::samples_t>& X = handler.getX();
|
vector<mdlp::samples_t>& X = handler.getX();
|
||||||
mdlp::labels_t& y = handler.getY();
|
mdlp::labels_t& y = handler.getY();
|
||||||
|
Loading…
Reference in New Issue
Block a user