Add subparser to be_report & tests

This commit is contained in:
2022-11-20 18:23:26 +01:00
parent 146304f4b5
commit 1b8a424ad3
7 changed files with 109 additions and 76 deletions

View File

@@ -49,10 +49,11 @@ class EnvDefault(argparse.Action):
class Arguments(argparse.ArgumentParser): class Arguments(argparse.ArgumentParser):
def __init__(self): def __init__(self, *args, **kwargs):
super().__init__() super().__init__(*args, **kwargs)
models_data = Models.define_models(random_state=0) models_data = Models.define_models(random_state=0)
self._overrides = {} self._overrides = {}
self._subparser = None
self.parameters = { self.parameters = {
"best": [ "best": [
("-b", "--best"), ("-b", "--best"),
@@ -100,19 +101,6 @@ class Arguments(argparse.ArgumentParser):
"help": "Generate Excel File", "help": "Generate Excel File",
}, },
], ],
"file": [
("-f", "--file"),
{"type": str, "required": False, "help": "Result file"},
],
"grid": [
("-g", "--grid"),
{
"action": "store_true",
"required": False,
"default": False,
"help": "grid results of model",
},
],
"grid_paramfile": [ "grid_paramfile": [
("-g", "--grid_paramfile"), ("-g", "--grid_paramfile"),
{ {
@@ -309,6 +297,23 @@ class Arguments(argparse.ArgumentParser):
) )
return self return self
def add_subparser(
self, dest="subcommand", help_text="help for subcommand"
):
self._subparser = self.add_subparsers(dest=dest, help=help_text)
def add_subparsers_options(self, subparser, arguments):
command, help_text = subparser
parser = self._subparser.add_parser(command, help=help_text)
for name, args in arguments:
try:
names, parameters = self.parameters[name]
except KeyError:
names = (name,)
parameters = {}
# Order of args is important
parser.add_argument(*names, **{**args, **parameters})
def parse(self, args=None): def parse(self, args=None):
for key, (dest_key, value) in self._overrides.items(): for key, (dest_key, value) in self._overrides.items():
if args is None: if args is None:

View File

@@ -251,7 +251,7 @@ class ReportBest(BaseReport):
"Hyperparameters", "Hyperparameters",
] ]
def __init__(self, score, model, best, grid): def __init__(self, score, model, best):
name = ( name = (
Files.best_results(score, model) Files.best_results(score, model)
if best if best
@@ -259,7 +259,6 @@ class ReportBest(BaseReport):
) )
file_name = os.path.join(Folders.results, name) file_name = os.path.join(Folders.results, name)
self.best = best self.best = best
self.grid = grid
self.score_name = score self.score_name = score
self.model = model self.model = model
super().__init__(file_name, best_file=True) super().__init__(file_name, best_file=True)

View File

@@ -21,5 +21,5 @@ def main(args_test=None):
print(e) print(e)
else: else:
if args.report: if args.report:
report = ReportBest(args.score, args.model, best=True, grid=False) report = ReportBest(args.score, args.model, best=True)
report.report() report.report()

View File

@@ -11,40 +11,71 @@ If no argument is set, displays the datasets and its characteristics
def main(args_test=None): def main(args_test=None):
arguments = Arguments() arguments = Arguments(prog="be_report")
arguments.xset("file").xset("excel").xset("sql").xset("compare") arguments.add_subparser()
arguments.xset("best").xset("grid").xset("model", required=False) arguments.add_subparsers_options(
arguments.xset("score", required=False) (
"best",
"Report best results obtained by any model/score. "
"See be_build_best",
),
[
("model", dict(required=False)),
("score", dict(required=False)),
],
)
arguments.add_subparsers_options(
(
"grid",
"Report grid results obtained by any model/score. "
"See be_build_grid",
),
[
("model", dict(required=False)),
("score", dict(required=False)),
],
)
arguments.add_subparsers_options(
("file", "Report file results"),
[
("file_name", {}),
("excel", {}),
("sql", {}),
("compare", {}),
],
)
arguments.add_subparsers_options(
("datasets", "Report datasets information"),
[
("excel", {}),
],
)
args = arguments.parse(args_test) args = arguments.parse(args_test)
if args.best: if args.subcommand == "best" or args.subcommand == "grid":
args.grid = False best = args.subcommand == "best"
if args.grid: report = ReportBest(args.score, args.model, best)
args.best = False report.report()
if args.file is None and not args.best and not args.grid: elif args.subcommand == "file":
try:
report = Report(args.file_name, args.compare)
report.report()
except FileNotFoundError as e:
print(e)
return
if args.sql:
sql = SQL(args.file_name)
sql.report()
if args.excel:
excel = Excel(
file_name=args.file_name,
compare=args.compare,
)
excel.report()
is_test = args_test is not None
Files.open(excel.get_file_name(), is_test)
else:
report = ReportDatasets(args.excel) report = ReportDatasets(args.excel)
report.report() report.report()
if args.excel: if args.excel:
is_test = args_test is not None is_test = args_test is not None
Files.open(report.get_file_name(), is_test) Files.open(report.get_file_name(), is_test)
else:
if args.best or args.grid:
report = ReportBest(args.score, args.model, args.best, args.grid)
report.report()
else:
try:
report = Report(args.file, args.compare)
except FileNotFoundError as e:
print(e)
else:
report.report()
if args.excel:
excel = Excel(
file_name=args.file,
compare=args.compare,
)
excel.report()
is_test = args_test is not None
Files.open(excel.get_file_name(), is_test)
if args.sql:
sql = SQL(args.file)
sql.report()

View File

@@ -24,13 +24,10 @@ class ArgumentsTest(TestBase):
def test_parameters(self): def test_parameters(self):
expected_parameters = { expected_parameters = {
"best": ("-b", "--best"),
"color": ("-c", "--color"), "color": ("-c", "--color"),
"compare": ("-c", "--compare"), "compare": ("-c", "--compare"),
"dataset": ("-d", "--dataset"), "dataset": ("-d", "--dataset"),
"excel": ("-x", "--excel"), "excel": ("-x", "--excel"),
"file": ("-f", "--file"),
"grid": ("-g", "--grid"),
"grid_paramfile": ("-g", "--grid_paramfile"), "grid_paramfile": ("-g", "--grid_paramfile"),
"hidden": ("--hidden",), "hidden": ("--hidden",),
"hyperparameters": ("-p", "--hyperparameters"), "hyperparameters": ("-p", "--hyperparameters"),

View File

@@ -69,13 +69,13 @@ class ReportTest(TestBase):
_ = Report("unknown_file") _ = Report("unknown_file")
def test_report_best(self): def test_report_best(self):
report = ReportBest("accuracy", "STree", best=True, grid=False) report = ReportBest("accuracy", "STree", best=True)
with patch(self.output, new=StringIO()) as stdout: with patch(self.output, new=StringIO()) as stdout:
report.report() report.report()
self.check_output_file(stdout, "report_best") self.check_output_file(stdout, "report_best")
def test_report_grid(self): def test_report_grid(self):
report = ReportBest("accuracy", "STree", best=False, grid=True) report = ReportBest("accuracy", "STree", best=False)
with patch(self.output, new=StringIO()) as stdout: with patch(self.output, new=StringIO()) as stdout:
report.report() report.report()
file_name = "report_grid.test" file_name = "report_grid.test"
@@ -90,12 +90,6 @@ class ReportTest(TestBase):
self.assertEqual(line, output_text[index]) self.assertEqual(line, output_text[index])
def test_report_best_both(self):
report = ReportBest("accuracy", "STree", best=True, grid=True)
with patch(self.output, new=StringIO()) as stdout:
report.report()
self.check_output_file(stdout, "report_best")
@patch("sys.stdout", new_callable=StringIO) @patch("sys.stdout", new_callable=StringIO)
def test_report_datasets(self, mock_output): def test_report_datasets(self, mock_output):
report = ReportDatasets() report = ReportDatasets()

View File

@@ -1,5 +1,7 @@
import os import os
from openpyxl import load_workbook from openpyxl import load_workbook
from io import StringIO
from unittest.mock import patch
from ...Utils import Folders, Files from ...Utils import Folders, Files
from ..TestBase import TestBase from ..TestBase import TestBase
from ..._version import __version__ from ..._version import __version__
@@ -23,25 +25,25 @@ class BeReportTest(TestBase):
"results", "results",
"results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json", "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json",
) )
stdout, stderr = self.execute_script("be_report", ["-f", file_name]) stdout, stderr = self.execute_script("be_report", ["file", file_name])
self.assertEqual(stderr.getvalue(), "") self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "report") self.check_output_file(stdout, "report")
def test_be_report_not_found(self): def test_be_report_not_found(self):
stdout, stderr = self.execute_script("be_report", ["-f", "unknown"]) stdout, stderr = self.execute_script("be_report", ["file", "unknown"])
self.assertEqual(stderr.getvalue(), "") self.assertEqual(stderr.getvalue(), "")
self.assertEqual(stdout.getvalue(), "unknown does not exists!\n") self.assertEqual(stdout.getvalue(), "unknown does not exists!\n")
def test_be_report_compare(self): def test_be_report_compare(self):
file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json" file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json"
stdout, stderr = self.execute_script( stdout, stderr = self.execute_script(
"be_report", ["-f", file_name, "-c"] "be_report", ["file", file_name, "-c"]
) )
self.assertEqual(stderr.getvalue(), "") self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "report_compared") self.check_output_file(stdout, "report_compared")
def test_be_report_datatsets(self): def test_be_report_datatsets(self):
stdout, stderr = self.execute_script("be_report", []) stdout, stderr = self.execute_script("be_report", ["datasets"])
self.assertEqual(stderr.getvalue(), "") self.assertEqual(stderr.getvalue(), "")
file_name = f"report_datasets{self.ext}" file_name = f"report_datasets{self.ext}"
with open(os.path.join(self.test_files, file_name)) as f: with open(os.path.join(self.test_files, file_name)) as f:
@@ -54,7 +56,7 @@ class BeReportTest(TestBase):
self.assertEqual(line, output_text[index]) self.assertEqual(line, output_text[index])
def test_be_report_datasets_excel(self): def test_be_report_datasets_excel(self):
stdout, stderr = self.execute_script("be_report", ["-x"]) stdout, stderr = self.execute_script("be_report", ["datasets", "-x"])
self.assertEqual(stderr.getvalue(), "") self.assertEqual(stderr.getvalue(), "")
file_name = f"report_datasets{self.ext}" file_name = f"report_datasets{self.ext}"
with open(os.path.join(self.test_files, file_name)) as f: with open(os.path.join(self.test_files, file_name)) as f:
@@ -77,14 +79,14 @@ class BeReportTest(TestBase):
def test_be_report_best(self): def test_be_report_best(self):
stdout, stderr = self.execute_script( stdout, stderr = self.execute_script(
"be_report", ["-s", "accuracy", "-m", "STree", "-b"] "be_report", ["best", "-s", "accuracy", "-m", "STree"]
) )
self.assertEqual(stderr.getvalue(), "") self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "report_best") self.check_output_file(stdout, "report_best")
def test_be_report_grid(self): def test_be_report_grid(self):
stdout, stderr = self.execute_script( stdout, stderr = self.execute_script(
"be_report", ["-s", "accuracy", "-m", "STree", "-g"] "be_report", ["grid", "-s", "accuracy", "-m", "STree"]
) )
self.assertEqual(stderr.getvalue(), "") self.assertEqual(stderr.getvalue(), "")
file_name = "report_grid.test" file_name = "report_grid.test"
@@ -98,19 +100,24 @@ class BeReportTest(TestBase):
line = self.replace_STree_version(line, output_text, index) line = self.replace_STree_version(line, output_text, index)
self.assertEqual(line, output_text[index]) self.assertEqual(line, output_text[index])
def test_be_report_best_both(self): @patch("sys.stderr", new_callable=StringIO)
stdout, stderr = self.execute_script( def test_be_report_unknown_subcommand(self, stderr):
"be_report", with self.assertRaises(SystemExit) as msg:
["-s", "accuracy", "-m", "STree", "-b", "-g"], module = self.search_script("be_report")
module.main(["unknown", "accuracy", "-m", "STree"])
self.assertEqual(msg.exception.code, 2)
self.assertEqual(
stderr.getvalue(),
"usage: be_report [-h] {best,grid,file,datasets} ...\n"
"be_report: error: argument subcommand: invalid choice: "
"'unknown' (choose from 'best', 'grid', 'file', 'datasets')\n",
) )
self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "report_best")
def test_be_report_excel_compared(self): def test_be_report_excel_compared(self):
file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json" file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json"
stdout, stderr = self.execute_script( stdout, stderr = self.execute_script(
"be_report", "be_report",
["-f", file_name, "-x", "-c"], ["file", file_name, "-x", "-c"],
) )
file_name = os.path.join( file_name = os.path.join(
Folders.results, file_name.replace(".json", ".xlsx") Folders.results, file_name.replace(".json", ".xlsx")
@@ -125,7 +132,7 @@ class BeReportTest(TestBase):
file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json" file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json"
stdout, stderr = self.execute_script( stdout, stderr = self.execute_script(
"be_report", "be_report",
["-f", file_name, "-x"], ["file", file_name, "-x"],
) )
file_name = os.path.join( file_name = os.path.join(
Folders.results, file_name.replace(".json", ".xlsx") Folders.results, file_name.replace(".json", ".xlsx")
@@ -140,7 +147,7 @@ class BeReportTest(TestBase):
file_name = "results_accuracy_ODTE_Galgo_2022-04-20_10:52:20_0.json" file_name = "results_accuracy_ODTE_Galgo_2022-04-20_10:52:20_0.json"
stdout, stderr = self.execute_script( stdout, stderr = self.execute_script(
"be_report", "be_report",
["-f", file_name, "-q"], ["file", file_name, "-q"],
) )
file_name = os.path.join( file_name = os.path.join(
Folders.results, file_name.replace(".json", ".sql") Folders.results, file_name.replace(".json", ".sql")