Add show experiment

This commit is contained in:
2023-07-29 16:31:36 +02:00
parent 85cb447283
commit b9e76becce
3 changed files with 31 additions and 10 deletions

View File

@@ -43,7 +43,7 @@ namespace platform {
result["discretized"] = discretized;
result["stratified"] = stratified;
result["folds"] = nfolds;
result["seeds"] = random_seeds;
result["seeds"] = randomSeeds;
result["duration"] = duration;
result["results"] = json::array();
for (auto& r : results) {
@@ -65,6 +65,10 @@ namespace platform {
j["test_time_std"] = r.getTestTimeStd();
j["time"] = r.getTestTime() + r.getTrainTime();
j["time_std"] = r.getTestTimeStd() + r.getTrainTimeStd();
j["scores_train"] = r.getScoresTrain();
j["scores_test"] = r.getScoresTest();
j["times_train"] = r.getTimesTrain();
j["times_test"] = r.getTimesTest();
j["nodes"] = r.getNodes();
j["leaves"] = r.getLeaves();
j["depth"] = r.getDepth();
@@ -79,6 +83,11 @@ namespace platform {
file << data;
file.close();
}
void Experiment::show()
{
json data = build_json();
cout << data.dump(4) << endl;
}
Result cross_validation(Fold* fold, string model_name, torch::Tensor& Xt, torch::Tensor& y, vector<string> features, string className, map<string, vector<int>> states)
{
auto classifiers = map<string, bayesnet::BaseClassifier*>({
@@ -101,7 +110,6 @@ namespace platform {
cout << "doing Fold: " << flush;
for (int i = 0; i < k; i++) {
bayesnet::BaseClassifier* model = classifiers[model_name];
result.setModelVersion(model->getVersion());
train_timer.start();
auto [train, test] = fold->getFold(i);
auto train_t = torch::tensor(train);
@@ -122,8 +130,9 @@ namespace platform {
test_time[i] = test_timer.getDuration();
accuracy_train[i] = accuracy_train_value;
accuracy_test[i] = accuracy_test_value;
}
cout << "end." << endl;
cout << "end. " << flush;
result.setScoreTest(torch::mean(accuracy_test).item<double>()).setScoreTrain(torch::mean(accuracy_train).item<double>());
result.setScoreTestStd(torch::std(accuracy_test).item<double>()).setScoreTrainStd(torch::std(accuracy_train).item<double>());
result.setTrainTime(torch::mean(train_time).item<double>()).setTestTime(torch::mean(test_time).item<double>());