From 53dafa340429730a736a0368bd49bea123234647 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Sat, 17 Feb 2024 19:09:43 +0100 Subject: [PATCH] Refactor Result & PartialResult classes Add title modification to b_manage --- src/Platform/CMakeLists.txt | 2 +- src/Platform/b_main.cc | 2 +- src/Platform/modules/BestResults.cc | 5 +- src/Platform/modules/Experiment.cc | 126 +++++--------------------- src/Platform/modules/Experiment.h | 97 ++++---------------- src/Platform/modules/ManageResults.cc | 20 +++- src/Platform/modules/Models.cc | 2 +- src/Platform/modules/PartialResult.h | 42 +++++++++ src/Platform/modules/Result.cc | 82 ++++++++++++----- src/Platform/modules/Result.h | 48 ++++++---- src/Platform/modules/Results.cc | 3 +- tests/TestResult.cc | 7 +- 12 files changed, 203 insertions(+), 233 deletions(-) create mode 100644 src/Platform/modules/PartialResult.h diff --git a/src/Platform/CMakeLists.txt b/src/Platform/CMakeLists.txt index 7fad509..8b44388 100644 --- a/src/Platform/CMakeLists.txt +++ b/src/Platform/CMakeLists.txt @@ -22,7 +22,7 @@ set(grid_sources GridSearch.cc GridData.cc HyperParameters.cc Datasets.cc Datase list(TRANSFORM grid_sources PREPEND ${Platform_MODULES}) add_executable(b_grid b_grid.cc ${grid_sources}) add_executable(b_list b_list.cc ${Platform_MODULES}Datasets.cc ${Platform_MODULES}Dataset.cc) -set(main_sources Experiment.cc Datasets.cc Dataset.cc Models.cc HyperParameters.cc ReportConsole.cc ReportBase.cc) +set(main_sources Experiment.cc Datasets.cc Dataset.cc Models.cc HyperParameters.cc ReportConsole.cc ReportBase.cc Result.cc) list(TRANSFORM main_sources PREPEND ${Platform_MODULES}) add_executable(b_main b_main.cc ${main_sources}) set(manage_sources Results.cc ManageResults.cc CommandParser.cc Result.cc ReportConsole.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc) diff --git a/src/Platform/b_main.cc b/src/Platform/b_main.cc index 0f10a65..3c5c46f 100644 --- a/src/Platform/b_main.cc +++ b/src/Platform/b_main.cc @@ -128,7 +128,7 @@ int main(int argc, char** argv) experiment.go(filesToTest, quiet, no_train_score); experiment.setDuration(timer.getDuration()); if (saveResults) { - experiment.save(platform::Paths::results()); + experiment.saveResult(); } if (!quiet) experiment.report(); diff --git a/src/Platform/modules/BestResults.cc b/src/Platform/modules/BestResults.cc index e015675..66742fa 100644 --- a/src/Platform/modules/BestResults.cc +++ b/src/Platform/modules/BestResults.cc @@ -36,8 +36,9 @@ namespace platform { } json bests; for (const auto& file : files) { - auto result = Result(path, file); - auto data = result.load(); + auto result = Result(); + result.load(path, file); + auto data = result.getJson(); for (auto const& item : data.at("results")) { bool update = true; auto datasetName = item.at("dataset").get(); diff --git a/src/Platform/modules/Experiment.cc b/src/Platform/modules/Experiment.cc index 5f614e6..a3e555f 100644 --- a/src/Platform/modules/Experiment.cc +++ b/src/Platform/modules/Experiment.cc @@ -1,4 +1,3 @@ -#include #include "Experiment.h" #include "Datasets.h" #include "Models.h" @@ -6,109 +5,27 @@ #include "Paths.h" namespace platform { using json = nlohmann::json; - std::string get_date() - { - time_t rawtime; - tm* timeinfo; - time(&rawtime); - timeinfo = std::localtime(&rawtime); - std::ostringstream oss; - oss << std::put_time(timeinfo, "%Y-%m-%d"); - return oss.str(); - } - std::string get_time() - { - time_t rawtime; - tm* timeinfo; - time(&rawtime); - timeinfo = std::localtime(&rawtime); - std::ostringstream oss; - oss << std::put_time(timeinfo, "%H:%M:%S"); - return oss.str(); - } - std::string Experiment::get_file_name() - { - std::string result = "results_" + score_name + "_" + model + "_" + platform + "_" + get_date() + "_" + get_time() + "_" + (stratified ? "1" : "0") + ".json"; - return result; - } - json Experiment::build_json() + void Experiment::saveResult() { - json result; - result["title"] = title; - result["date"] = get_date(); - result["time"] = get_time(); - result["model"] = model; - result["version"] = model_version; - result["platform"] = platform; - result["score_name"] = score_name; - result["language"] = language; - result["language_version"] = language_version; - result["discretized"] = discretized; - result["stratified"] = stratified; - result["folds"] = nfolds; - result["seeds"] = randomSeeds; - result["duration"] = duration; - result["results"] = json::array(); - for (const auto& r : results) { - json j; - j["dataset"] = r.getDataset(); - j["hyperparameters"] = r.getHyperparameters(); - j["samples"] = r.getSamples(); - j["features"] = r.getFeatures(); - j["classes"] = r.getClasses(); - j["score_train"] = r.getScoreTrain(); - j["score_test"] = r.getScoreTest(); - j["score"] = r.getScoreTest(); - j["score_std"] = r.getScoreTestStd(); - j["score_train_std"] = r.getScoreTrainStd(); - j["score_test_std"] = r.getScoreTestStd(); - j["train_time"] = r.getTrainTime(); - j["train_time_std"] = r.getTrainTimeStd(); - j["test_time"] = r.getTestTime(); - j["test_time_std"] = r.getTestTimeStd(); - j["time"] = r.getTestTime() + r.getTrainTime(); - j["time_std"] = r.getTestTimeStd() + r.getTrainTimeStd(); - j["scores_train"] = r.getScoresTrain(); - j["scores_test"] = r.getScoresTest(); - j["times_train"] = r.getTimesTrain(); - j["times_test"] = r.getTimesTest(); - j["nodes"] = r.getNodes(); - j["leaves"] = r.getLeaves(); - j["depth"] = r.getDepth(); - j["notes"] = r.getNotes(); - result["results"].push_back(j); - } - return result; + result.save(); } - void Experiment::save(const std::string& path) - { - json data = build_json(); - ofstream file(path + "/" + get_file_name()); - file << data; - file.close(); - } - void Experiment::report() { - json data = build_json(); - ReportConsole report(data); + ReportConsole report(result.getJson()); report.show(); } - void Experiment::show() { - json data = build_json(); - std::cout << data.dump(4) << std::endl; + std::cout << result.getJson().dump(4) << std::endl; } - void Experiment::go(std::vector filesToProcess, bool quiet, bool no_train_score) { for (auto fileName : filesToProcess) { if (fileName.size() > max_name) max_name = fileName.size(); } - std::cout << Colors::MAGENTA() << "*** Starting experiment: " << title << " ***" << Colors::RESET() << std::endl << std::endl; + std::cout << Colors::MAGENTA() << "*** Starting experiment: " << result.getTitle() << " ***" << Colors::RESET() << std::endl << std::endl; if (!quiet) { std::cout << Colors::GREEN() << " Status Meaning" << std::endl; std::cout << " ------ -----------------------------" << Colors::RESET() << std::endl; @@ -130,7 +47,6 @@ namespace platform { if (!quiet) std::cout << std::endl; } - std::string getColor(bayesnet::status_t status) { switch (status) { @@ -163,11 +79,11 @@ namespace platform { if (!quiet) { std::cout << " " << setw(5) << samples << " " << setw(5) << features.size() << flush; } - // Prepare Resu lt - auto result = Result(); + // Prepare Result + auto partial_result = PartialResult(); auto [values, counts] = at::_unique(y); - result.setSamples(X.size(1)).setFeatures(X.size(0)).setClasses(values.size(0)); - result.setHyperparameters(hyperparameters.get(fileName)); + partial_result.setSamples(X.size(1)).setFeatures(X.size(0)).setClasses(values.size(0)); + partial_result.setHyperparameters(hyperparameters.get(fileName)); // Initialize results std::vectors int nResults = nfolds * static_cast(randomSeeds.size()); auto accuracy_test = torch::zeros({ nResults }, torch::kFloat64); @@ -196,7 +112,7 @@ namespace platform { else fold = new folding::KFold(nfolds, y.size(0), seed); for (int nfold = 0; nfold < nfolds; nfold++) { - auto clf = Models::instance()->create(model); + auto clf = Models::instance()->create(result.getModel()); setModelVersion(clf->getVersion()); auto valid = clf->getValidHyperparameters(); hyperparameters.check(valid, fileName); @@ -238,22 +154,22 @@ namespace platform { if (!quiet) std::cout << "\b\b\b, " << flush; // Store results and times in std::vector - result.addScoreTrain(accuracy_train_value); - result.addScoreTest(accuracy_test_value); - result.addTimeTrain(train_time[item].item()); - result.addTimeTest(test_time[item].item()); + partial_result.addScoreTrain(accuracy_train_value); + partial_result.addScoreTest(accuracy_test_value); + partial_result.addTimeTrain(train_time[item].item()); + partial_result.addTimeTest(test_time[item].item()); item++; } if (!quiet) std::cout << "end. " << flush; delete fold; } - result.setScoreTest(torch::mean(accuracy_test).item()).setScoreTrain(torch::mean(accuracy_train).item()); - result.setScoreTestStd(torch::std(accuracy_test).item()).setScoreTrainStd(torch::std(accuracy_train).item()); - result.setTrainTime(torch::mean(train_time).item()).setTestTime(torch::mean(test_time).item()); - result.setTestTimeStd(torch::std(test_time).item()).setTrainTimeStd(torch::std(train_time).item()); - result.setNodes(torch::mean(nodes).item()).setLeaves(torch::mean(edges).item()).setDepth(torch::mean(num_states).item()); - result.setDataset(fileName).setNotes(notes); - addResult(result); + partial_result.setScoreTest(torch::mean(accuracy_test).item()).setScoreTrain(torch::mean(accuracy_train).item()); + partial_result.setScoreTestStd(torch::std(accuracy_test).item()).setScoreTrainStd(torch::std(accuracy_train).item()); + partial_result.setTrainTime(torch::mean(train_time).item()).setTestTime(torch::mean(test_time).item()); + partial_result.setTestTimeStd(torch::std(test_time).item()).setTrainTimeStd(torch::std(train_time).item()); + partial_result.setNodes(torch::mean(nodes).item()).setLeaves(torch::mean(edges).item()).setDepth(torch::mean(num_states).item()); + partial_result.setDataset(fileName).setNotes(notes); + addResult(partial_result); } } \ No newline at end of file diff --git a/src/Platform/modules/Experiment.h b/src/Platform/modules/Experiment.h index 53f7e85..f645b40 100644 --- a/src/Platform/modules/Experiment.h +++ b/src/Platform/modules/Experiment.h @@ -6,102 +6,41 @@ #include "folding.hpp" #include "BaseClassifier.h" #include "HyperParameters.h" -#include "TAN.h" -#include "KDB.h" -#include "AODE.h" -#include "Timer.h" +#include "Result.h" namespace platform { using json = nlohmann::json; - class Result { - private: - std::string dataset, model_version; - json hyperparameters; - int samples{ 0 }, features{ 0 }, classes{ 0 }; - double score_train{ 0 }, score_test{ 0 }, score_train_std{ 0 }, score_test_std{ 0 }, train_time{ 0 }, train_time_std{ 0 }, test_time{ 0 }, test_time_std{ 0 }; - float nodes{ 0 }, leaves{ 0 }, depth{ 0 }; - std::vector scores_train, scores_test, times_train, times_test; - std::vector notes; - public: - Result() = default; - Result& setDataset(const std::string& dataset) { this->dataset = dataset; return *this; } - Result& setNotes(const std::vector& notes) { this->notes.insert(this->notes.end(), notes.begin(), notes.end()); return *this; } - Result& setHyperparameters(const json& hyperparameters) { this->hyperparameters = hyperparameters; return *this; } - Result& setSamples(int samples) { this->samples = samples; return *this; } - Result& setFeatures(int features) { this->features = features; return *this; } - Result& setClasses(int classes) { this->classes = classes; return *this; } - Result& setScoreTrain(double score) { this->score_train = score; return *this; } - Result& setScoreTest(double score) { this->score_test = score; return *this; } - Result& setScoreTrainStd(double score_std) { this->score_train_std = score_std; return *this; } - Result& setScoreTestStd(double score_std) { this->score_test_std = score_std; return *this; } - Result& setTrainTime(double train_time) { this->train_time = train_time; return *this; } - Result& setTrainTimeStd(double train_time_std) { this->train_time_std = train_time_std; return *this; } - Result& setTestTime(double test_time) { this->test_time = test_time; return *this; } - Result& setTestTimeStd(double test_time_std) { this->test_time_std = test_time_std; return *this; } - Result& setNodes(float nodes) { this->nodes = nodes; return *this; } - Result& setLeaves(float leaves) { this->leaves = leaves; return *this; } - Result& setDepth(float depth) { this->depth = depth; return *this; } - Result& addScoreTrain(double score) { scores_train.push_back(score); return *this; } - Result& addScoreTest(double score) { scores_test.push_back(score); return *this; } - Result& addTimeTrain(double time) { times_train.push_back(time); return *this; } - Result& addTimeTest(double time) { times_test.push_back(time); return *this; } - const float get_score_train() const { return score_train; } - float get_score_test() { return score_test; } - const std::string& getDataset() const { return dataset; } - const json& getHyperparameters() const { return hyperparameters; } - const int getSamples() const { return samples; } - const int getFeatures() const { return features; } - const int getClasses() const { return classes; } - const double getScoreTrain() const { return score_train; } - const double getScoreTest() const { return score_test; } - const double getScoreTrainStd() const { return score_train_std; } - const double getScoreTestStd() const { return score_test_std; } - const double getTrainTime() const { return train_time; } - const double getTrainTimeStd() const { return train_time_std; } - const double getTestTime() const { return test_time; } - const double getTestTimeStd() const { return test_time_std; } - const float getNodes() const { return nodes; } - const float getLeaves() const { return leaves; } - const float getDepth() const { return depth; } - const std::vector& getNotes() const { return notes; } - const std::vector& getScoresTrain() const { return scores_train; } - const std::vector& getScoresTest() const { return scores_test; } - const std::vector& getTimesTrain() const { return times_train; } - const std::vector& getTimesTest() const { return times_test; } - }; + class Experiment { public: Experiment() = default; - Experiment& setTitle(const std::string& title) { this->title = title; return *this; } - Experiment& setModel(const std::string& model) { this->model = model; return *this; } - Experiment& setPlatform(const std::string& platform) { this->platform = platform; return *this; } - Experiment& setScoreName(const std::string& score_name) { this->score_name = score_name; return *this; } - Experiment& setModelVersion(const std::string& model_version) { this->model_version = model_version; return *this; } - Experiment& setLanguage(const std::string& language) { this->language = language; return *this; } - Experiment& setLanguageVersion(const std::string& language_version) { this->language_version = language_version; return *this; } - Experiment& setDiscretized(bool discretized) { this->discretized = discretized; return *this; } - Experiment& setStratified(bool stratified) { this->stratified = stratified; return *this; } - Experiment& setNFolds(int nfolds) { this->nfolds = nfolds; return *this; } - Experiment& addResult(Result result) { results.push_back(result); return *this; } - Experiment& addRandomSeed(int randomSeed) { randomSeeds.push_back(randomSeed); return *this; } - Experiment& setDuration(float duration) { this->duration = duration; return *this; } + Experiment& setPlatform(const std::string& platform) { this->result.setPlatform(platform); return *this; } + Experiment& setScoreName(const std::string& score_name) { this->result.setScoreName(score_name); return *this; } + Experiment& setTitle(const std::string& title) { this->result.setTitle(title); return *this; } + Experiment& setModelVersion(const std::string& model_version) { this->result.setModelVersion(model_version); return *this; } + Experiment& setModel(const std::string& model) { this->result.setModel(model); return *this; } + Experiment& setLanguage(const std::string& language) { this->result.setLanguage(language); return *this; } + Experiment& setLanguageVersion(const std::string& language_version) { this->result.setLanguageVersion(language_version); return *this; } + Experiment& setDiscretized(bool discretized) { this->discretized = discretized; result.setDiscretized(discretized); return *this; } + Experiment& setStratified(bool stratified) { this->stratified = stratified; result.setStratified(stratified); return *this; } + Experiment& setNFolds(int nfolds) { this->nfolds = nfolds; result.setNFolds(nfolds); return *this; } + Experiment& addResult(PartialResult result_) { result.addPartial(result_); return *this; } + Experiment& addRandomSeed(int randomSeed) { randomSeeds.push_back(randomSeed); result.addSeed(randomSeed); return *this; } + Experiment& setDuration(float duration) { this->result.setDuration(duration); return *this; } Experiment& setHyperparameters(const HyperParameters& hyperparameters_) { this->hyperparameters = hyperparameters_; return *this; } - std::string get_file_name(); - void save(const std::string& path); void cross_validation(const std::string& fileName, bool quiet, bool no_train_score); void go(std::vector filesToProcess, bool quiet, bool no_train_score); + void saveResult(); void show(); void report(); private: - std::string title, model, platform, score_name, model_version, language_version, language; + Result result; bool discretized{ false }, stratified{ false }; - std::vector results; + std::vector results; std::vector randomSeeds; HyperParameters hyperparameters; int nfolds{ 0 }; int max_name{ 7 }; // max length of dataset name for formatting (default 7) - float duration{ 0 }; - json build_json(); }; } #endif \ No newline at end of file diff --git a/src/Platform/modules/ManageResults.cc b/src/Platform/modules/ManageResults.cc index dc03979..196fe3c 100644 --- a/src/Platform/modules/ManageResults.cc +++ b/src/Platform/modules/ManageResults.cc @@ -87,7 +87,7 @@ namespace platform { void ManageResults::report(const int index, const bool excelReport) { std::cout << Colors::YELLOW() << "Reporting " << results.at(index).getFilename() << std::endl; - auto data = results.at(index).load(); + auto data = results.at(index).getJson(); if (excelReport) { ReportExcel reporter(data, compare, workbook); reporter.show(); @@ -102,7 +102,7 @@ namespace platform { void ManageResults::showIndex(const int index, const int idx) { // Show a dataset result inside a report - auto data = results.at(index).load(); + auto data = results.at(index).getJson(); std::cout << Colors::YELLOW() << "Showing " << results.at(index).getFilename() << std::endl; ReportConsole reporter(data, compare, idx); reporter.show(); @@ -151,7 +151,8 @@ namespace platform { {"hide", 'h', true}, {"sort", 's', false}, {"report", 'r', true}, - {"excel", 'e', true} + {"excel", 'e', true}, + {"title", 't', true} }; std::vector> listOptions = { {"report", 'r', true}, @@ -163,7 +164,7 @@ namespace platform { if (indexList) { std::tie(option, index) = parser.parse(Colors::GREEN(), mainOptions, 'r', numFiles - 1); } else { - std::tie(option, subIndex) = parser.parse(Colors::MAGENTA(), listOptions, 'r', results.at(index).load()["results"].size() - 1); + std::tie(option, subIndex) = parser.parse(Colors::MAGENTA(), listOptions, 'r', results.at(index).getJson()["results"].size() - 1); } switch (option) { case 'q': @@ -207,6 +208,17 @@ namespace platform { case 'e': report(index, true); break; + case 't': + std::cout << "Title: " << results.at(index).getTitle() << std::endl; + std::cout << "New title: "; + std::string newTitle; + getline(std::cin, newTitle); + if (!newTitle.empty()) { + results.at(index).setTitle(newTitle); + results.at(index).save(); + std::cout << "Title changed to " << newTitle << std::endl; + } + break; } } } diff --git a/src/Platform/modules/Models.cc b/src/Platform/modules/Models.cc index f73c0d6..10929e4 100644 --- a/src/Platform/modules/Models.cc +++ b/src/Platform/modules/Models.cc @@ -27,7 +27,7 @@ namespace platform { if (instance != nullptr) return unique_ptr(instance); else - return nullptr; + throw std::runtime_error("Model not found: " + name); } std::vector Models::getNames() { diff --git a/src/Platform/modules/PartialResult.h b/src/Platform/modules/PartialResult.h new file mode 100644 index 0000000..51a04c0 --- /dev/null +++ b/src/Platform/modules/PartialResult.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include +using json = nlohmann::json; + +class PartialResult { + +public: + PartialResult() { data["scores_train"] = json::array(); data["scores_test"] = json::array(); data["times_train"] = json::array(); data["times_test"] = json::array(); }; + PartialResult& setDataset(const std::string& dataset) { data["dataset"] = dataset; return *this; } + PartialResult& setNotes(const std::vector& notes) { this->notes.insert(this->notes.end(), notes.begin(), notes.end()); return *this; } + PartialResult& setHyperparameters(const json& hyperparameters) { data["hyperparameters"] = hyperparameters; return *this; } + PartialResult& setSamples(int samples) { data["samples"] = samples; return *this; } + PartialResult& setFeatures(int features) { data["features"] = features; return *this; } + PartialResult& setClasses(int classes) { data["classes"] = classes; return *this; } + PartialResult& setScoreTrain(double score) { data["score_train"] = score; return *this; } + PartialResult& setScoreTrainStd(double score_std) { data["score_train_std"] = score_std; return *this; } + PartialResult& setScoreTest(double score) { data["score"] = score; return *this; } + PartialResult& setScoreTestStd(double score_std) { data["score_std"] = score_std; return *this; } + PartialResult& setTrainTime(double train_time) { data["train_time"] = train_time; return *this; } + PartialResult& setTrainTimeStd(double train_time_std) { data["train_time_std"] = train_time_std; return *this; } + PartialResult& setTestTime(double test_time) { data["test_time"] = test_time; return *this; } + PartialResult& setTestTimeStd(double test_time_std) { data["test_time_std"] = test_time_std; return *this; } + PartialResult& setNodes(float nodes) { data["nodes"] = nodes; return *this; } + PartialResult& setLeaves(float leaves) { data["leaves"] = leaves; return *this; } + PartialResult& setDepth(float depth) { data["depth"] = depth; return *this; } + PartialResult& addScoreTrain(double score) { data["scores_train"].push_back(score); return *this; } + PartialResult& addScoreTest(double score) { data["scores_test"].push_back(score); return *this; } + PartialResult& addTimeTrain(double time) { data["times_train"].push_back(time); return *this; } + PartialResult& addTimeTest(double time) { data["times_test"].push_back(time); return *this; } + json getJson() + { + data["time"] = data["test_time"].get() + data["train_time"].get(); + data["time_std"] = data["test_time_std"].get() + data["train_time_std"].get(); + data["notes"] = notes; + return data; + } +private: + json data; + std::vector notes; +}; \ No newline at end of file diff --git a/src/Platform/modules/Result.cc b/src/Platform/modules/Result.cc index 43c33d1..ec5f584 100644 --- a/src/Platform/modules/Result.cc +++ b/src/Platform/modules/Result.cc @@ -1,58 +1,98 @@ #include "Result.h" -#include "BestScore.h" #include #include #include +#include "BestScore.h" #include "Colors.h" #include "DotEnv.h" #include "CLocale.h" +#include "Paths.h" namespace platform { - Result::Result(const std::string& path, const std::string& filename) - : path(path) - , filename(filename) + std::string get_actual_date() { - auto data = load(); - date = data["date"]; + time_t rawtime; + tm* timeinfo; + time(&rawtime); + timeinfo = std::localtime(&rawtime); + std::ostringstream oss; + oss << std::put_time(timeinfo, "%Y-%m-%d"); + return oss.str(); + } + std::string get_actual_time() + { + time_t rawtime; + tm* timeinfo; + time(&rawtime); + timeinfo = std::localtime(&rawtime); + std::ostringstream oss; + oss << std::put_time(timeinfo, "%H:%M:%S"); + return oss.str(); + } + Result::Result() + { + data["date"] = get_actual_date(); + data["time"] = get_actual_time(); + data["results"] = json::array(); + data["seeds"] = json::array(); + } + + Result& Result::load(const std::string& path, const std::string& fileName) + { + std::ifstream resultData(path + "/" + fileName); + if (resultData.is_open()) { + data = json::parse(resultData); + } else { + throw std::invalid_argument("Unable to open result file. [" + path + "/" + fileName + "]"); + } score = 0; for (const auto& result : data["results"]) { score += result["score"].get(); } - scoreName = data["score_name"]; + auto scoreName = data["score_name"]; auto best = BestScore::getScore(scoreName); if (best.first != "") { score /= best.second; } - title = data["title"]; - duration = data["duration"]; - model = data["model"]; complete = data["results"].size() > 1; + return *this; + } + json Result::getJson() + { + return data; } - json Result::load() const + void Result::save() { - std::ifstream resultData(path + "/" + filename); - if (resultData.is_open()) { - json data = json::parse(resultData); - return data; - } - throw std::invalid_argument("Unable to open result file. [" + path + "/" + filename + "]"); + std::ofstream file(Paths::results() + "/" + getFilename()); + file << data; + file.close(); } + std::string Result::getFilename() const + { + std::ostringstream oss; + oss << "results_" << data.at("score_name").get() << "_" << data.at("model").get() << "_" + << data.at("platform").get() << "_" << data["date"].get() << "_" + << data["time"].get() << "_" << (data["stratified"] ? "1" : "0") << ".json"; + return oss.str(); + } + std::string Result::to_string(int maxModel) const { auto tmp = ConfigLocale(); std::stringstream oss; + auto duration = data["duration"].get(); double durationShow = duration > 3600 ? duration / 3600 : duration > 60 ? duration / 60 : duration; std::string durationUnit = duration > 3600 ? "h" : duration > 60 ? "m" : "s"; - oss << date << " "; - oss << std::setw(maxModel) << std::left << model << " "; - oss << std::setw(11) << std::left << scoreName << " "; + oss << data["date"].get() << " "; + oss << std::setw(maxModel) << std::left << data["model"].get() << " "; + oss << std::setw(11) << std::left << data["score_name"].get() << " "; oss << std::right << std::setw(11) << std::setprecision(7) << std::fixed << score << " "; auto completeString = isComplete() ? "C" : "P"; oss << std::setw(1) << " " << completeString << " "; oss << std::setw(7) << std::setprecision(2) << std::fixed << durationShow << " " << durationUnit << " "; - oss << std::setw(50) << std::left << title << " "; + oss << std::setw(50) << std::left << data["title"].get() << " "; return oss.str(); } } \ No newline at end of file diff --git a/src/Platform/modules/Result.h b/src/Platform/modules/Result.h index 10459b7..6e1d4bd 100644 --- a/src/Platform/modules/Result.h +++ b/src/Platform/modules/Result.h @@ -4,32 +4,48 @@ #include #include #include +#include "HyperParameters.h" +#include "PartialResult.h" +#include "Timer.h" + namespace platform { using json = nlohmann::json; class Result { public: - Result(const std::string& path, const std::string& filename); - json load() const; + Result(); + Result& load(const std::string& path, const std::string& filename); + void save(); + // Getters + json getJson(); std::string to_string(int maxModel) const; - std::string getFilename() const { return filename; }; - std::string getDate() const { return date; }; + std::string getFilename() const; + std::string getDate() const { return data["date"].get(); }; double getScore() const { return score; }; - std::string getTitle() const { return title; }; - double getDuration() const { return duration; }; - std::string getModel() const { return model; }; - std::string getScoreName() const { return scoreName; }; + std::string getTitle() const { return data["title"].get(); }; + double getDuration() const { return data["duration"]; }; + std::string getModel() const { return data["model"].get(); }; + std::string getScoreName() const { return data["score_name"].get(); }; bool isComplete() const { return complete; }; + // Setters + void setTitle(const std::string& title) { data["title"] = title; }; + void setLanguage(const std::string& language) { data["language"] = language; }; + void setLanguageVersion(const std::string& language_version) { data["language_version"] = language_version; }; + void setDuration(double duration) { data["duration"] = duration; }; + void setModel(const std::string& model) { data["model"] = model; }; + void setModelVersion(const std::string& model_version) { data["version"] = model_version; }; + void setScoreName(const std::string& scoreName) { data["score_name"] = scoreName; }; + void setDiscretized(bool discretized) { data["discretized"] = discretized; }; + void addSeed(int seed) { data["seeds"].push_back(seed); }; + void addPartial(PartialResult& partial_result) { data["results"].push_back(partial_result.getJson()); }; + void setStratified(bool stratified) { data["stratified"] = stratified; }; + void setNFolds(int nfolds) { data["folds"] = nfolds; }; + void setPlatform(const std::string& platform_name) { data["platform"] = platform_name; }; + private: - std::string path; - std::string filename; - std::string date; - double score; - std::string title; - double duration; - std::string model; - std::string scoreName; + json data; bool complete; + double score = 0.0; }; }; #endif \ No newline at end of file diff --git a/src/Platform/modules/Results.cc b/src/Platform/modules/Results.cc index 6f5fe14..cb7c41a 100644 --- a/src/Platform/modules/Results.cc +++ b/src/Platform/modules/Results.cc @@ -18,7 +18,8 @@ namespace platform { for (const auto& file : directory_iterator(path)) { auto filename = file.path().filename().string(); if (filename.find(".json") != std::string::npos && filename.find("results_") == 0) { - auto result = Result(path, filename); + auto result = Result(); + result.load(path, filename); bool addResult = true; if (model != "any" && result.getModel() != model || scoreName != "any" && scoreName != result.getScoreName() || complete && !result.isComplete() || partial && result.isComplete()) addResult = false; diff --git a/tests/TestResult.cc b/tests/TestResult.cc index f03cba7..05e2622 100644 --- a/tests/TestResult.cc +++ b/tests/TestResult.cc @@ -10,7 +10,8 @@ TEST_CASE("Result class tests", "[Result]") SECTION("Constructor and load method") { - platform::Result result(testPath, testFile); + platform::Result result; + result.load(testPath, testFile); REQUIRE(result.date != ""); REQUIRE(result.score >= 0); REQUIRE(result.scoreName != ""); @@ -22,6 +23,7 @@ TEST_CASE("Result class tests", "[Result]") SECTION("to_string method") { platform::Result result(testPath, testFile); + result.load(); std::string resultStr = result.to_string(1); REQUIRE(resultStr != ""); } @@ -29,6 +31,7 @@ TEST_CASE("Result class tests", "[Result]") SECTION("Exception handling in load method") { std::string invalidFile = "invalid.json"; - REQUIRE_THROWS_AS(platform::Result(testPath, invalidFile), std::invalid_argument); + auto result = platform::Result(); + REQUIRE_THROWS_AS(platform::result.load(testPath, invalidFile), std::invalid_argument); } } \ No newline at end of file