Add screen width control in b_manage

This commit is contained in:
2024-07-15 18:06:39 +02:00
parent 2f2ed00ca1
commit f2556a30af
10 changed files with 108 additions and 76 deletions

View File

@@ -57,7 +57,7 @@ add_executable(b_main commands/b_main.cpp ${main_sources}
target_link_libraries(b_main "${PyClassifiers}" "${BayesNet}" mdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy)
# b_manage
set(manage_sources ManageScreen.cpp CommandParser.cpp ResultsManager.cpp)
set(manage_sources ManageScreen.cpp OptionsMenu.cpp ResultsManager.cpp)
list(TRANSFORM manage_sources PREPEND manage/)
add_executable(
b_manage commands/b_manage.cpp ${manage_sources}

View File

@@ -12,6 +12,7 @@ namespace platform {
inline static const std::string downward_arrow{ "\u27B4" };
inline static const std::string up_arrow{ "\u2B06" };
inline static const std::string down_arrow{ "\u2B07" };
inline static const std::string ellipsis{ "\u2026" };
inline static const std::string equal_best{ check_mark };
inline static const std::string better_best{ black_star };
inline static const std::string notebook{ "\U0001F5C8" };

View File

@@ -54,6 +54,10 @@ namespace platform {
std::vector<double> aucScores(nClasses, 0.0);
std::vector<std::pair<double, int>> scoresAndLabels;
for (size_t classIdx = 0; classIdx < nClasses; ++classIdx) {
if (classIdx >= y_proba.size(1)) {
std::cerr << "AUC warning - class index out of range" << std::endl;
return 0;
}
scoresAndLabels.clear();
for (size_t i = 0; i < nSamples; ++i) {
scoresAndLabels.emplace_back(y_proba[i][classIdx].item<float>(), y_testv[i] == classIdx ? 1 : 0);

View File

@@ -1,21 +0,0 @@
#ifndef COMMAND_PARSER_H
#define COMMAND_PARSER_H
#include <string>
#include <vector>
#include <tuple>
namespace platform {
class CommandParser {
public:
CommandParser() = default;
std::tuple<char, int, bool> parse(const std::string& color, const std::vector<std::tuple<std::string, char, bool>>& options, const char defaultCommand, const int minIndex, const int maxIndex);
char getCommand() const { return command; };
int getIndex() const { return index; };
std::string getErrorMessage() const { return errorMessage; };
private:
std::string errorMessage;
char command;
int index;
};
} /* namespace platform */
#endif

View File

@@ -3,10 +3,9 @@
#include <string>
#include <algorithm>
#include "folding.hpp"
#include "common/Colors.h"
#include "common/CLocale.h"
#include "common/Paths.h"
#include "CommandParser.h"
#include "OptionsMenu.h"
#include "ManageScreen.h"
#include "reports/DatasetsConsole.h"
#include "reports/ReportConsole.h"
@@ -18,6 +17,7 @@
namespace platform {
const std::string STATUS_OK = "Ok.";
const std::string STATUS_COLOR = Colors::GREEN();
ManageScreen::ManageScreen(int rows, int cols, const std::string& model, const std::string& score, const std::string& platform, bool complete, bool partial, bool compare) :
rows{ rows }, cols{ cols }, complete{ complete }, partial{ partial }, compare{ compare }, didExcel(false), results(ResultsManager(model, score, platform, complete, partial))
{
@@ -25,7 +25,20 @@ namespace platform {
openExcel = false;
workbook = NULL;
this->rows = std::max(0, rows - 6); // 6 is the number of lines used by the menu & header
cols = std::max(cols, 140);
maxModel = results.maxModelSize();
maxTitle = results.maxTitleSize();
header_lengths = { 3, 10, maxModel, 11, 10, 12, 2, 3, 7, maxTitle };
header_labels = { " #", "Date", "Model", "Score Name", "Score", "Platform", "SD", "C/P", "Time", "Title" };
sort_fields = { "Date", "Model", "Score", "Time" };
int minTitle = 10;
// set 10 chars as minimum for Title
int columns = std::accumulate(header_lengths.begin(), header_lengths.end(), 0) + header_lengths.size() - maxTitle + minTitle;
if (columns > cols) {
throw std::runtime_error("Make screen bigger to fit the results! " + std::to_string(columns - cols) + " columns needed! ");
}
maxTitle = minTitle + cols - columns;
header_lengths[header_lengths.size() - 1] = maxTitle;
cols = std::min(cols, columns + maxTitle);
// Initializes the paginator for each output type (experiments, datasets, result)
for (int i = 0; i < static_cast<int>(OutputType::Count); i++) {
paginator.push_back(Paginator(this->rows, results.size()));
@@ -115,7 +128,6 @@ namespace platform {
}
void ManageScreen::list_result(const std::string& status_message, const std::string& status_color)
{
auto data = results.at(index).getJson();
ReportConsole report(data, compare);
auto header_text = report.getHeader();
@@ -140,11 +152,9 @@ namespace platform {
// Status Area
//
footer(status_message, status_color);
}
void ManageScreen::list_detail(const std::string& status_message, const std::string& status_color)
{
auto data = results.at(index).getJson();
ReportConsole report(data, compare, subIndex);
auto header_text = report.getHeader();
@@ -169,7 +179,6 @@ namespace platform {
// Status Area
//
footer(status_message, status_color);
}
void ManageScreen::list_datasets(const std::string& status_message, const std::string& status_color)
{
@@ -193,7 +202,6 @@ namespace platform {
// Status Area
//
footer(status_message, status_color);
}
void ManageScreen::list_experiments(const std::string& status_message, const std::string& status_color)
{
@@ -201,17 +209,10 @@ namespace platform {
// header
//
header();
//
// Field names
//
int maxModel = results.maxModelSize();
int maxTitle = results.maxTitleSize();
std::vector<int> header_lengths = { 3, 10, maxModel, 11, 10, 12, 2, 3, 7, maxTitle };
std::cout << Colors::RESET();
std::string arrow_dn = Symbols::down_arrow + " ";
std::string arrow_up = Symbols::up_arrow + " ";
std::vector<std::string> header_labels = { " #", "Date", "Model", "Score Name", "Score", "Platform", "SD", "C/P", "Time", "Title" };
std::vector<std::string> sort_fields = { "Date", "Model", "Score", "Time" };
for (int i = 0; i < header_labels.size(); i++) {
std::string suffix = "", color = Colors::GREEN();
int diff = 0;
@@ -234,7 +235,7 @@ namespace platform {
for (int i = index_from; i <= index_to; i++) {
auto color = (i % 2) ? Colors::BLUE() : Colors::CYAN();
std::cout << color << std::setw(3) << std::fixed << std::right << i << " ";
std::cout << results.at(i).to_string(maxModel) << std::endl;
std::cout << results.at(i).to_string(maxModel, maxTitle) << std::endl;
}
//
// Status Area
@@ -290,7 +291,6 @@ namespace platform {
std::pair<std::string, std::string> ManageScreen::sortList()
{
std::cout << Colors::YELLOW() << "Choose sorting field (date='d', score='s', time='t', model='m', ascending='+', descending='-'): ";
std::vector<std::string> fields = { "Date", "Model", "Score", "Time" };
std::string invalid_option = "Invalid sorting option";
std::string line;
char option;
@@ -322,7 +322,7 @@ namespace platform {
return { Colors::RED(), invalid_option };
}
results.sortResults(sort_field, sort_type);
return { Colors::GREEN(), "Sorted by " + fields[static_cast<int>(sort_field)] + " " + (sort_type == SortType::ASC ? "ascending" : "descending") };
return { Colors::GREEN(), "Sorted by " + sort_fields[static_cast<int>(sort_field)] + " " + (sort_type == SortType::ASC ? "ascending" : "descending") };
}
void ManageScreen::menu()
{
@@ -333,17 +333,17 @@ namespace platform {
std::vector<std::tuple<std::string, char, bool>> mainOptions = {
{"quit", 'q', false},
{"list", 'l', false},
{"delete", 'D', true},
{"Delete", 'D', true},
{"datasets", 'd', false},
{"hide", 'h', true},
{"sort", 's', false},
{"report", 'r', true},
{"excel", 'e', true},
{"title", 't', true},
{"set A", 'a', true},
{"set B", 'b', true},
{"set A", 'A', true},
{"set B", 'B', true},
{"compare A~B", 'c', false},
{"Page", 'p', true},
{"page", 'p', true},
{"Page+", '+', false },
{"Page-", '-', false}
};
@@ -354,23 +354,20 @@ namespace platform {
{"list", 'l', false},
{"excel", 'e', false},
{"back", 'b', false},
{"Page", 'p', true},
{"page", 'p', true},
{"Page+", '+', false},
{"Page-", '-', false}
};
auto parser = CommandParser();
auto main_menu = OptionsMenu(mainOptions, Colors::IGREEN(), Colors::YELLOW(), cols);
auto list_menu = OptionsMenu(listOptions, Colors::IBLUE(), Colors::YELLOW(), cols);
while (!finished) {
OptionsMenu& menu = output_type == OutputType::EXPERIMENTS ? main_menu : list_menu;
bool parserError = true; // force the first iteration
while (parserError) {
auto [min_index, max_index] = paginator[static_cast<int>(output_type)].getOffset();
if (output_type == OutputType::EXPERIMENTS) {
std::tie(option, index, parserError) = parser.parse(Colors::IGREEN(), mainOptions, 'r', min_index, max_index);
} else {
std::tie(option, subIndex, parserError) = parser.parse(Colors::IBLUE(), listOptions, 'r', min_index, max_index);
}
std::tie(option, index, parserError) = menu.parse('r', min_index, max_index);
if (parserError) {
list(parser.getErrorMessage(), Colors::RED());
list(menu.getErrorMessage(), Colors::RED());
}
}
switch (option) {
@@ -405,7 +402,7 @@ namespace platform {
case 'q':
finished = true;
break;
case 'a':
case 'A':
if (index == index_B) {
list("A and B cannot be the same!", Colors::RED());
break;
@@ -413,7 +410,7 @@ namespace platform {
index_A = index;
list("A set to " + std::to_string(index), Colors::GREEN());
break;
case 'b': // set_b or back to list
case 'B': // set_b or back to list
if (output_type == OutputType::EXPERIMENTS) {
if (index == index_A) {
list("A and B cannot be the same!", Colors::RED());

View File

@@ -2,6 +2,7 @@
#define MANAGE_SCREEN_H
#include <xlsxwriter.h>
#include "ResultsManager.h"
#include "common/Colors.h"
#include "Paginator.hpp"
namespace platform {
@@ -43,6 +44,10 @@ namespace platform {
bool complete;
bool partial;
bool compare;
int maxModel, maxTitle;
std::vector<std::string> header_labels;
std::vector<int> header_lengths;
std::vector<std::string> sort_fields;
SortField sort_field = SortField::DATE;
SortType sort_type = SortType::DESC;
std::vector<Paginator> paginator;

View File

@@ -1,30 +1,46 @@
#include "CommandParser.h"
#include "OptionsMenu.h"
#include <iostream>
#include <sstream>
#include <algorithm>
#include "common/Colors.h"
#include "common/Utils.h"
namespace platform {
std::tuple<char, int, bool> CommandParser::parse(const std::string& color, const std::vector<std::tuple<std::string, char, bool>>& options, const char defaultCommand, const int minIndex, const int maxIndex)
std::string OptionsMenu::to_string()
{
bool first = true;
size_t size = 0;
std::string result = color_normal + "Options: (";
for (auto& option : options) {
if (!first) {
result += ", ";
size += 2;
}
std::string title = std::get<0>(option);
auto pos = title.find(std::get<1>(option));
result += color_normal + title.substr(0, pos) + color_bold + title.substr(pos, 1) + color_normal + title.substr(pos + 1);
size += title.size();
first = false;
}
if (size + 3 > cols) { // 3 is the size of the "): " at the end
result = "";
first = true;
for (auto& option : options) {
if (!first) {
result += color_normal + ", ";
}
result += color_bold + std::get<1>(option);
first = false;
}
}
result += "): ";
return result;
}
std::tuple<char, int, bool> OptionsMenu::parse(char defaultCommand, int minIndex, int maxIndex)
{
bool finished = false;
while (!finished) {
std::stringstream oss;
std::cout << to_string();
std::string line;
oss << color << "Options (";
bool first = true;
for (auto& option : options) {
if (first) {
first = false;
} else {
oss << ", ";
}
oss << std::get<char>(option) << "=" << std::get<std::string>(option);
}
oss << "): ";
std::cout << oss.str();
getline(std::cin, line);
line = trim(line);
if (line.size() == 0) {

25
src/manage/OptionsMenu.h Normal file
View File

@@ -0,0 +1,25 @@
#ifndef OPTIONS_MENU_H
#define OPTIONS_MENU_H
#include <string>
#include <vector>
#include <tuple>
namespace platform {
class OptionsMenu {
public:
OptionsMenu(std::vector<std::tuple<std::string, char, bool>>& options, std::string color_normal, std::string color_bold, int cols) : options(options), color_normal(color_normal), color_bold(color_bold), cols(cols) {}
std::string to_string();
std::tuple<char, int, bool> parse(char defaultCommand, int minIndex, int maxIndex);
char getCommand() const { return command; };
int getIndex() const { return index; };
std::string getErrorMessage() const { return errorMessage; };
private:
std::vector<std::tuple<std::string, char, bool>>& options;
std::string color_normal, color_bold;
int cols;
std::string errorMessage;
char command;
int index;
};
} /* namespace platform */
#endif

View File

@@ -6,6 +6,7 @@
#include "common/DotEnv.h"
#include "common/CLocale.h"
#include "common/Paths.h"
#include "common/Symbols.h"
#include "Result.h"
namespace platform {
@@ -78,7 +79,7 @@ namespace platform {
}
std::string Result::to_string(int maxModel) const
std::string Result::to_string(int maxModel, int maxTitle) const
{
auto tmp = ConfigLocale();
std::stringstream oss;
@@ -97,7 +98,11 @@ namespace platform {
auto completeString = isComplete() ? "C" : "P";
oss << std::setw(1) << " " << completeString << " ";
oss << std::setw(5) << std::right << std::setprecision(2) << std::fixed << durationShow << " " << durationUnit << " ";
oss << std::setw(50) << std::left << data["title"].get<std::string>() << " ";
auto title = data["title"].get<std::string>();
if (title.size() > maxTitle) {
title = title.substr(0, maxTitle - 1) + Symbols::ellipsis;
}
oss << std::setw(maxTitle) << std::left << title;
return oss.str();
}
}

View File

@@ -18,7 +18,7 @@ namespace platform {
void save();
// Getters
json getJson();
std::string to_string(int maxModel) const;
std::string to_string(int maxModel, int maxTitle) const;
std::string getFilename() const;
std::string getDate() const { return data["date"].get<std::string>(); };
std::string getTime() const { return data["time"].get<std::string>(); };