Complete json output compatible with benchmark

This commit is contained in:
2023-07-26 19:01:39 +02:00
parent 6f7fb290b0
commit 3e954ba841
8 changed files with 88 additions and 16 deletions

View File

@@ -2,36 +2,48 @@
namespace platform {
using json = nlohmann::json;
string get_date_time()
string get_date()
{
time_t rawtime;
tm* timeinfo;
time(&rawtime);
timeinfo = std::localtime(&rawtime);
std::ostringstream oss;
oss << std::put_time(timeinfo, "%Y-%m-%d_%H:%M:%S");
oss << std::put_time(timeinfo, "%Y-%m-%d");
return oss.str();
}
string get_time()
{
time_t rawtime;
tm* timeinfo;
time(&rawtime);
timeinfo = std::localtime(&rawtime);
std::ostringstream oss;
oss << std::put_time(timeinfo, "%H:%M:%S");
return oss.str();
}
string Experiment::get_file_name()
{
string date_time = get_date_time();
string result = "results_" + score_name + "_" + model + "_" + platform + "_" + date_time + "_" + (stratified ? "1" : "0") + ".json";
string result = "results_" + score_name + "_" + model + "_" + platform + "_" + get_date() + "_" + get_time() + "_" + (stratified ? "1" : "0") + ".json";
return result;
}
json Experiment::build_json()
{
json result;
result["title"] = title;
result["date"] = get_date();
result["time"] = get_time();
result["model"] = model;
result["version"] = model_version;
result["platform"] = platform;
result["score_name"] = score_name;
result["model_version"] = model_version;
result["language"] = language;
result["language_version"] = language_version;
result["discretized"] = discretized;
result["stratified"] = stratified;
result["nfolds"] = nfolds;
result["random_seeds"] = random_seeds;
result["folds"] = nfolds;
result["seeds"] = random_seeds;
result["duration"] = duration;
result["results"] = json::array();
for (auto& r : results) {
@@ -43,12 +55,19 @@ namespace platform {
j["classes"] = r.getClasses();
j["score_train"] = r.getScoreTrain();
j["score_test"] = r.getScoreTest();
j["score"] = r.getScoreTest();
j["score_std"] = r.getScoreTestStd();
j["score_train_std"] = r.getScoreTrainStd();
j["score_test_std"] = r.getScoreTestStd();
j["train_time"] = r.getTrainTime();
j["train_time_std"] = r.getTrainTimeStd();
j["test_time"] = r.getTestTime();
j["test_time_std"] = r.getTestTimeStd();
j["time"] = r.getTestTime() + r.getTrainTime();
j["time_std"] = r.getTestTimeStd() + r.getTrainTimeStd();
j["nodes"] = r.getNodes();
j["leaves"] = r.getLeaves();
j["depth"] = r.getDepth();
result["results"].push_back(j);
}
return result;
@@ -69,11 +88,16 @@ namespace platform {
);
auto Xt = torch::transpose(X, 0, 1);
auto result = Result();
auto [values, counts] = at::_unique(y);
result.setSamples(X.size(0)).setFeatures(X.size(1)).setClasses(values.size(0));
auto k = fold->getNumberOfFolds();
auto accuracy_test = torch::zeros({ k }, torch::kFloat64);
auto accuracy_train = torch::zeros({ k }, torch::kFloat64);
auto train_time = torch::zeros({ k }, torch::kFloat64);
auto test_time = torch::zeros({ k }, torch::kFloat64);
auto nodes = torch::zeros({ k }, torch::kFloat64);
auto edges = torch::zeros({ k }, torch::kFloat64);
auto num_states = torch::zeros({ k }, torch::kFloat64);
Timer train_timer, test_timer;
for (int i = 0; i < k; i++) {
bayesnet::BaseClassifier* model = classifiers[model_name];
@@ -86,6 +110,9 @@ namespace platform {
auto X_test = Xt.index({ "...", test_t });
auto y_test = y.index({ test_t });
model->fit(X_train, y_train, features, className, states);
nodes[i] = model->getNumberOfNodes();
edges[i] = model->getNumberOfEdges();
num_states[i] = model->getNumberOfStates();
cout << "Training Fold " << i + 1 << endl;
cout << "X_train: " << X_train.sizes() << endl;
cout << "y_train: " << y_train.sizes() << endl;
@@ -102,6 +129,7 @@ namespace platform {
result.setScoreTest(torch::mean(accuracy_test).item<double>()).setScoreTrain(torch::mean(accuracy_train).item<double>());
result.setScoreTestStd(torch::std(accuracy_test).item<double>()).setScoreTrainStd(torch::std(accuracy_train).item<double>());
result.setTrainTime(torch::mean(train_time).item<double>()).setTestTime(torch::mean(test_time).item<double>());
result.setNodes(torch::mean(nodes).item<double>()).setLeaves(torch::mean(edges).item<double>()).setDepth(torch::mean(num_states).item<double>());
return result;
}
}