Add overrides to args parse for dataset/title in be_main

This commit is contained in:
2022-11-19 21:16:29 +01:00
parent 68d9cb776e
commit 07172b91c5
12 changed files with 92 additions and 59 deletions

View File

@@ -36,6 +36,7 @@ class EnvDefault(argparse.Action):
self, envvar, required=True, default=None, mandatory=False, **kwargs
):
self._args = EnvData.load()
self._overrides = {}
if required and not mandatory:
default = self._args[envvar]
required = False
@@ -51,20 +52,22 @@ class Arguments:
def __init__(self):
self.ap = argparse.ArgumentParser()
models_data = Models.define_models(random_state=0)
self._overrides = {}
self.parameters = {
"best": [
("-b", "--best"),
{
"type": str,
"required": False,
"action": "store_true",
"default": False,
"help": "best results of models",
},
],
"color": [
("-c", "--color"),
{
"type": bool,
"required": False,
"action": "store_true",
"default": False,
"help": "use colors for the tree",
},
@@ -72,8 +75,9 @@ class Arguments:
"compare": [
("-c", "--compare"),
{
"type": bool,
"action": "store_true",
"required": False,
"default": False,
"help": "Compare accuracy with best results",
},
],
@@ -81,6 +85,8 @@ class Arguments:
("-d", "--dataset"),
{
"type": str,
"envvar": "dataset", # for compatiblity with EnvDefault
"action": EnvDefault,
"required": False,
"help": "dataset to work with",
},
@@ -88,8 +94,8 @@ class Arguments:
"excel": [
("-x", "--excel"),
{
"type": bool,
"required": False,
"action": "store_true",
"default": False,
"help": "Generate Excel File",
},
@@ -101,16 +107,17 @@ class Arguments:
"grid": [
("-g", "--grid"),
{
"type": str,
"action": "store_true",
"required": False,
"default": False,
"help": "grid results of model",
},
],
"grid_paramfile": [
("-g", "--grid_paramfile"),
{
"type": bool,
"required": False,
"action": "store_true",
"default": False,
"help": "Use best hyperparams file?",
},
@@ -118,8 +125,8 @@ class Arguments:
"hidden": [
("--hidden",),
{
"type": str,
"required": False,
"action": "store_true",
"default": False,
"help": "Show hidden results",
},
@@ -140,8 +147,8 @@ class Arguments:
"lose": [
("-l", "--lose"),
{
"type": bool,
"default": False,
"action": "store_true",
"required": False,
"help": "show lose results",
},
@@ -178,8 +185,9 @@ class Arguments:
"nan": [
("--nan",),
{
"type": bool,
"action": "store_true",
"required": False,
"default": False,
"help": "Move nan results to hidden folder",
},
],
@@ -205,7 +213,7 @@ class Arguments:
"paramfile": [
("-f", "--paramfile"),
{
"type": bool,
"action": "store_true",
"required": False,
"default": False,
"help": "Use best hyperparams file?",
@@ -224,7 +232,7 @@ class Arguments:
"quiet": [
("-q", "--quiet"),
{
"type": bool,
"action": "store_true",
"required": False,
"default": False,
},
@@ -232,7 +240,7 @@ class Arguments:
"report": [
("-r", "--report"),
{
"type": bool,
"action": "store_true",
"default": False,
"required": False,
"help": "Report results",
@@ -250,23 +258,27 @@ class Arguments:
],
"sql": [
("-q", "--sql"),
{"type": bool, "required": False, "help": "Generate SQL File"},
{
"required": False,
"action": "store_true",
"default": False,
"help": "Generate SQL File",
},
],
"stratified": [
("-t", "--stratified"),
{
"action": EnvDefault,
"envvar": "stratified",
"type": str,
"required": True,
"required": False,
"help": "Stratified",
},
],
"tex_output": [
("-t", "--tex-output"),
{
"type": bool,
"required": False,
"action": "store_true",
"default": False,
"help": "Generate Tex file with the table",
},
@@ -278,8 +290,8 @@ class Arguments:
"win": [
("-w", "--win"),
{
"type": bool,
"default": False,
"action": "store_true",
"required": False,
"help": "show win results",
},
@@ -287,12 +299,20 @@ class Arguments:
}
def xset(self, *arg_name, **kwargs):
names, default = self.parameters[arg_name[0]]
names, parameters = self.parameters[arg_name[0]]
if "overrides" in kwargs:
self._overrides[names[0]] = (kwargs["overrides"], kwargs["const"])
del kwargs["overrides"]
self.ap.add_argument(
*names,
**{**default, **kwargs},
**{**parameters, **kwargs},
)
return self
def parse(self, args=None):
for key, (dest_key, value) in self._overrides.items():
if args is None:
args = sys.argv[1:]
if key in args:
args.extend((f"--{dest_key}", value))
return self.ap.parse_args(args)

View File

@@ -14,7 +14,10 @@ def main(args_test=None):
arguments.xset("stratified").xset("score").xset("model", mandatory=True)
arguments.xset("n_folds").xset("platform").xset("quiet").xset("title")
arguments.xset("hyperparameters").xset("paramfile").xset("report")
arguments.xset("grid_paramfile").xset("dataset")
arguments.xset("grid_paramfile")
arguments.xset(
"dataset", overrides="title", const="Test with only one dataset"
)
args = arguments.parse(args_test)
report = args.report or args.dataset is not None
if args.grid_paramfile:

View File

@@ -17,17 +17,17 @@ def main(args_test=None):
arguments.xset("score", required=False)
args = arguments.parse(args_test)
if args.best:
args.grid = None
args.grid = False
if args.grid:
args.best = None
if args.file is None and args.best is None and args.grid is None:
args.best = False
if args.file is None and not args.best and not args.grid:
report = ReportDatasets(args.excel)
report.report()
if args.excel:
is_test = args_test is not None
Files.open(report.get_file_name(), is_test)
else:
if args.best is not None or args.grid is not None:
if args.best or args.grid:
report = ReportBest(args.score, args.model, args.best, args.grid)
report.report()
else:

View File

@@ -98,3 +98,27 @@ class ArgumentsTest(TestBase):
finally:
os.chdir(path)
self.assertEqual(stderr.getvalue(), f"{NO_ENV}\n")
@patch("sys.stderr", new_callable=StringIO)
def test_overrides(self, stderr):
arguments = self.build_args()
arguments.xset("title")
arguments.xset("dataset", overrides="title", const="sample text")
test_args = ["-n", "3", "-m", "SVC", "-k", "1", "-d", "dataset"]
args = arguments.parse(test_args)
self.assertEqual(stderr.getvalue(), "")
self.assertEqual(args.title, "sample text")
@patch("sys.stderr", new_callable=StringIO)
def test_overrides_no_args(self, stderr):
arguments = self.build_args()
arguments.xset("title")
arguments.xset("dataset", overrides="title", const="sample text")
test_args = None
with self.assertRaises(SystemExit):
arguments.parse(test_args)
self.assertRegexpMatches(
stderr.getvalue(),
r"error: the following arguments are required: -m/--model, "
"-k/--key, --title",
)

View File

@@ -25,7 +25,7 @@ class BeBenchmarkTest(TestBase):
def test_be_benchmark_complete(self):
stdout, stderr = self.execute_script(
"be_benchmark", ["-s", self.score, "-q", "1", "-t", "1", "-x", "1"]
"be_benchmark", ["-s", self.score, "-q", "-t", "-x"]
)
self.assertEqual(stderr.getvalue(), "")
# Check output
@@ -60,7 +60,7 @@ class BeBenchmarkTest(TestBase):
def test_be_benchmark_single(self):
stdout, stderr = self.execute_script(
"be_benchmark", ["-s", self.score, "-q", "1"]
"be_benchmark", ["-s", self.score, "-q"]
)
self.assertEqual(stderr.getvalue(), "")
# Check output

View File

@@ -67,7 +67,7 @@ class BeBestTest(TestBase):
def test_be_build_best_report(self):
stdout, _ = self.execute_script(
"be_build_best", ["-s", "accuracy", "-m", "ODTE", "-r", "1"]
"be_build_best", ["-s", "accuracy", "-m", "ODTE", "-r"]
)
expected_data = {
"balance-scale": [

View File

@@ -69,7 +69,7 @@ class BeGridTest(TestBase):
def test_be_grid_no_input(self):
stdout, stderr = self.execute_script(
"be_grid",
["-m", "ODTE", "-s", "f1-weighted", "-q", "1"],
["-m", "ODTE", "-s", "f1-weighted", "-q"],
)
self.assertEqual(stderr.getvalue(), "")
grid_file = os.path.join(

View File

@@ -29,9 +29,7 @@ class BeListTest(TestBase):
@patch("benchmark.Results.get_input", side_effect=iter(["q"]))
def test_be_list_report_excel_none(self, input_data):
stdout, stderr = self.execute_script(
"be_list", ["-m", "STree", "-x", "1"]
)
stdout, stderr = self.execute_script("be_list", ["-m", "STree", "-x"])
self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "be_list_model")
@@ -43,9 +41,7 @@ class BeListTest(TestBase):
@patch("benchmark.Results.get_input", side_effect=iter(["2", "q"]))
def test_be_list_report_excel(self, input_data):
stdout, stderr = self.execute_script(
"be_list", ["-m", "STree", "-x", "1"]
)
stdout, stderr = self.execute_script("be_list", ["-m", "STree", "-x"])
self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "be_list_report_excel")
book = load_workbook(Files.be_list_excel)
@@ -54,9 +50,7 @@ class BeListTest(TestBase):
@patch("benchmark.Results.get_input", side_effect=iter(["2", "1", "q"]))
def test_be_list_report_excel_twice(self, input_data):
stdout, stderr = self.execute_script(
"be_list", ["-m", "STree", "-x", "1"]
)
stdout, stderr = self.execute_script("be_list", ["-m", "STree", "-x"])
self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "be_list_report_excel_2")
book = load_workbook(Files.be_list_excel)
@@ -87,7 +81,7 @@ class BeListTest(TestBase):
swap_files(Folders.hidden_results, Folders.results, file_name)
try:
# list and move nan result to hidden
stdout, stderr = self.execute_script("be_list", ["--nan", "1"])
stdout, stderr = self.execute_script("be_list", ["--nan"])
self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "be_list_nan")
except Exception:
@@ -97,7 +91,7 @@ class BeListTest(TestBase):
@patch("benchmark.Results.get_input", return_value="q")
def test_be_list_nan_no_nan(self, input_data):
stdout, stderr = self.execute_script("be_list", ["--nan", "1"])
stdout, stderr = self.execute_script("be_list", ["--nan"])
self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "be_list_no_nan")

View File

@@ -30,7 +30,7 @@ class BeMainTest(TestBase):
def test_be_main_complete(self):
stdout, _ = self.execute_script(
"be_main",
["-s", self.score, "-m", "STree", "--title", "test", "-r", "1"],
["-s", self.score, "-m", "STree", "--title", "test", "-r"],
)
# keep the report name to delete it after
report_name = stdout.getvalue().splitlines()[-1].split("in ")[1]
@@ -67,9 +67,7 @@ class BeMainTest(TestBase):
"--title",
"test",
"-f",
"1",
"-r",
"1",
],
)
# keep the report name to delete it after
@@ -91,9 +89,7 @@ class BeMainTest(TestBase):
"--title",
"test",
"-f",
"1",
"-r",
"1",
],
)
self.assertEqual(stderr.getvalue(), "")
@@ -117,9 +113,7 @@ class BeMainTest(TestBase):
"--title",
"test",
"-g",
"1",
"-r",
"1",
],
)
self.assertEqual(stderr.getvalue(), "")
@@ -142,9 +136,7 @@ class BeMainTest(TestBase):
"--title",
"test",
"-g",
"1",
"-r",
"1",
],
)
# keep the report name to delete it after

View File

@@ -18,7 +18,7 @@ class BePrintStrees(TestBase):
for name in self.datasets:
stdout, _ = self.execute_script(
"be_print_strees",
["-d", name, "-q", "1"],
["-d", name, "-q"],
)
file_name = os.path.join(Folders.img, f"stree_{name}.png")
self.files.append(file_name)
@@ -33,7 +33,7 @@ class BePrintStrees(TestBase):
for name in self.datasets:
stdout, _ = self.execute_script(
"be_print_strees",
["-d", name, "-q", "1", "-c", "1"],
["-d", name, "-q", "-c"],
)
file_name = os.path.join(Folders.img, f"stree_{name}.png")
self.files.append(file_name)

View File

@@ -35,7 +35,7 @@ class BeReportTest(TestBase):
def test_be_report_compare(self):
file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json"
stdout, stderr = self.execute_script(
"be_report", ["-f", file_name, "-c", "1"]
"be_report", ["-f", file_name, "-c"]
)
self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "report_compared")
@@ -54,7 +54,7 @@ class BeReportTest(TestBase):
self.assertEqual(line, output_text[index])
def test_be_report_datasets_excel(self):
stdout, stderr = self.execute_script("be_report", ["-x", "1"])
stdout, stderr = self.execute_script("be_report", ["-x"])
self.assertEqual(stderr.getvalue(), "")
file_name = f"report_datasets{self.ext}"
with open(os.path.join(self.test_files, file_name)) as f:
@@ -77,14 +77,14 @@ class BeReportTest(TestBase):
def test_be_report_best(self):
stdout, stderr = self.execute_script(
"be_report", ["-s", "accuracy", "-m", "STree", "-b", "1"]
"be_report", ["-s", "accuracy", "-m", "STree", "-b"]
)
self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "report_best")
def test_be_report_grid(self):
stdout, stderr = self.execute_script(
"be_report", ["-s", "accuracy", "-m", "STree", "-g", "1"]
"be_report", ["-s", "accuracy", "-m", "STree", "-g"]
)
self.assertEqual(stderr.getvalue(), "")
file_name = "report_grid.test"
@@ -101,7 +101,7 @@ class BeReportTest(TestBase):
def test_be_report_best_both(self):
stdout, stderr = self.execute_script(
"be_report",
["-s", "accuracy", "-m", "STree", "-b", "1", "-g", "1"],
["-s", "accuracy", "-m", "STree", "-b", "-g"],
)
self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "report_best")
@@ -110,7 +110,7 @@ class BeReportTest(TestBase):
file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json"
stdout, stderr = self.execute_script(
"be_report",
["-f", file_name, "-x", "1", "-c", "1"],
["-f", file_name, "-x", "-c"],
)
file_name = os.path.join(
Folders.results, file_name.replace(".json", ".xlsx")
@@ -125,7 +125,7 @@ class BeReportTest(TestBase):
file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json"
stdout, stderr = self.execute_script(
"be_report",
["-f", file_name, "-x", "1"],
["-f", file_name, "-x"],
)
file_name = os.path.join(
Folders.results, file_name.replace(".json", ".xlsx")
@@ -140,7 +140,7 @@ class BeReportTest(TestBase):
file_name = "results_accuracy_ODTE_Galgo_2022-04-20_10:52:20_0.json"
stdout, stderr = self.execute_script(
"be_report",
["-f", file_name, "-q", "1"],
["-f", file_name, "-q"],
)
file_name = os.path.join(
Folders.results, file_name.replace(".json", ".sql")

View File

@@ -1,6 +1,6 @@
*************************************************************************************************************************
* STree ver. 1.2.4 Python ver. 3.11x with 5 Folds cross validation and 10 random seeds. 2022-05-08 19:38:28 *
* test *
* Test with only one dataset *
* Random seeds: [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] Stratified: False *
* Execution took 0.06 seconds, 0.00 hours, on iMac27 *
* Score is accuracy *