Add results to b_list

This commit is contained in:
2024-03-10 18:02:03 +01:00
parent 05d05e25c2
commit cd9ff89b52
9 changed files with 184 additions and 27 deletions

View File

@@ -0,0 +1,54 @@
#include <algorithm>
#include "common/Paths.h"
#include "ResultsDataset.h"
namespace platform {
ResultsDataset::ResultsDataset(const std::string& dataset, const std::string& model, const std::string& score) :
path(Paths::results()), dataset(dataset), model(model), scoreName(score), maxModel(0), maxFile(0), maxHyper(15), maxResult(0)
{
}
void ResultsDataset::load()
{
using std::filesystem::directory_iterator;
for (const auto& file : directory_iterator(path)) {
auto filename = file.path().filename().string();
if (filename.find(".json") != std::string::npos && filename.find("results_") == 0) {
auto result = Result();
result.load(path, filename);
if (model != "any" && result.getModel() != model)
continue;
auto data = result.getData()["results"];
for (auto const& item : data) {
if (item["dataset"] == dataset) {
auto hyper_length = item["hyperparameters"].dump().size();
if (hyper_length > maxHyper)
maxHyper = hyper_length;
if (item["score"].get<double>() > maxResult)
maxResult = item["score"].get<double>();
files.push_back(result);
break;
}
}
}
}
maxModel = std::max(size_t(5), (*max_element(files.begin(), files.end(), [](const Result& a, const Result& b) { return a.getModel().size() < b.getModel().size(); })).getModel().size());
maxFile = std::max(size_t(4), (*max_element(files.begin(), files.end(), [](const Result& a, const Result& b) { return a.getFilename().size() < b.getFilename().size(); })).getFilename().size());
}
int ResultsDataset::size() const
{
return files.size();
}
void ResultsDataset::sortModel()
{
sort(files.begin(), files.end(), [](const Result& a, const Result& b) {
if (a.getModel() == b.getModel()) {
return a.getDate() > b.getDate();
}
return a.getModel() < b.getModel();
});
}
bool ResultsDataset::empty() const
{
return files.empty();
}
}

34
src/list/ResultsDataset.h Normal file
View File

@@ -0,0 +1,34 @@
#pragma once
#include <vector>
#include <string>
#include <nlohmann/json.hpp>
#include "main/Result.h"
namespace platform {
using json = nlohmann::json;
class ResultsDataset {
public:
ResultsDataset(const std::string& dataset, const std::string& model, const std::string& score);
void load(); // Loads the list of results
void sortModel();
int maxModelSize() const { return maxModel; };
int maxFileSize() const { return maxFile; };
int maxHyperSize() const { return maxHyper; };
double maxResultScore() const { return maxResult; };
int size() const;
bool empty() const;
std::vector<Result>::iterator begin() { return files.begin(); };
std::vector<Result>::iterator end() { return files.end(); };
Result& at(int index) { return files.at(index); };
private:
std::string path;
std::string dataset;
std::string model;
std::string scoreName;
int maxModel;
int maxFile;
int maxHyper;
double maxResult;
std::vector<Result> files;
};
};

View File

@@ -3,10 +3,13 @@
#include <map>
#include <argparse/argparse.hpp>
#include <nlohmann/json.hpp>
#include "main/Models.h"
#include "main/modelRegister.h"
#include "common/Paths.h"
#include "common/Colors.h"
#include "common/Datasets.h"
#include "DatasetsExcel.h"
#include "ResultsDataset.h"
#include "config.h"
const int BALANCE_LENGTH = 75;
@@ -32,7 +35,7 @@ std::string outputBalance(const std::string& balance)
void list_datasets(argparse::ArgumentParser& program)
{
auto datasets = platform::Datasets(false, platform::Paths::datasets());
auto excel = program.get<bool>("--excel");
auto excel = program.get<bool>("excel");
locale mylocale(std::cout.getloc(), new separated);
locale::global(mylocale);
std::cout.imbue(mylocale);
@@ -74,10 +77,44 @@ void list_datasets(argparse::ArgumentParser& program)
void list_results(argparse::ArgumentParser& program)
{
std::cout << "Results" << std::endl;
auto dataset = program.get<string>("--dataset");
auto score = program.get<string>("--score");
auto dataset = program.get<string>("dataset");
auto score = program.get<string>("score");
auto model = program.get<string>("model");
auto results = platform::ResultsDataset(dataset, model, score);
results.load();
results.sortModel();
if (results.empty()) {
std::cerr << Colors::RED() << "No results found for dataset " << dataset << " and model " << model << Colors::RESET() << std::endl;
exit(1);
}
//
// List data
//
int maxModel = results.maxModelSize();
int maxFileName = results.maxFileSize();
int maxHyper = results.maxHyperSize();
double maxResult = results.maxResultScore();
std::cout << Colors::GREEN() << "Results for dataset " << dataset << std::endl;
std::cout << "There are " << results.size() << " results" << std::endl;
std::cout << Colors::GREEN() << " # " << std::setw(maxModel + 1) << std::left << "Model" << "Date Score " << std::setw(maxFileName) << "File" << " Hyperparameters" << std::endl;
std::cout << "=== " << std::string(maxModel, '=') << " ========== =========== " << std::string(maxFileName, '=') << " " << std::string(maxHyper, '=') << std::endl;
auto i = 0;
for (const auto& result : results) {
auto data = result.getData();
for (const auto& item : data["results"]) {
if (item["dataset"] == dataset) {
auto color = (i % 2) ? Colors::BLUE() : Colors::CYAN();
color = item["score"].get<double>() == maxResult ? Colors::RED() : color;
std::cout << color << std::setw(3) << std::fixed << std::right << i++ << " ";
std::cout << std::setw(maxModel) << std::left << result.getModel() << " ";
std::cout << color << result.getDate() << " ";
std::cout << std::setw(11) << std::setprecision(9) << std::fixed << item["score"].get<double>() << " ";
std::cout << std::setw(maxFileName) << result.getFilename() << " ";
std::cout << item["hyperparameters"].dump() << std::endl;
break;
}
}
}
}
int main(int argc, char** argv)
@@ -110,7 +147,20 @@ int main(int argc, char** argv)
throw std::runtime_error("Dataset must be one of " + datasets.toString());
}
);
program.add_argument("-s", "--score").default_value("accuracy").help("Filter results of the score name supplied");
results_command.add_argument("-m", "--model")
.help("Model to use: " + platform::Models::instance()->toString() + " or any")
.default_value("any")
.action([](const std::string& value) {
std::vector<std::string> valid(platform::Models::instance()->getNames());
valid.push_back("any");
static const std::vector<std::string> choices = valid;
if (find(choices.begin(), choices.end(), value) != choices.end()) {
return value;
}
throw std::runtime_error("Model must be one of " + platform::Models::instance()->toString() + " or any");
}
);
results_command.add_argument("-s", "--score").default_value("accuracy").help("Filter results of the score name supplied");
// Add subparsers
program.add_subparser(datasets_command);
program.add_subparser(results_command);