mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-15 23:45:54 +00:00
Add overrides to args parse for dataset/title in be_main
This commit is contained in:
@@ -36,6 +36,7 @@ class EnvDefault(argparse.Action):
|
||||
self, envvar, required=True, default=None, mandatory=False, **kwargs
|
||||
):
|
||||
self._args = EnvData.load()
|
||||
self._overrides = {}
|
||||
if required and not mandatory:
|
||||
default = self._args[envvar]
|
||||
required = False
|
||||
@@ -51,20 +52,22 @@ class Arguments:
|
||||
def __init__(self):
|
||||
self.ap = argparse.ArgumentParser()
|
||||
models_data = Models.define_models(random_state=0)
|
||||
self._overrides = {}
|
||||
self.parameters = {
|
||||
"best": [
|
||||
("-b", "--best"),
|
||||
{
|
||||
"type": str,
|
||||
"required": False,
|
||||
"action": "store_true",
|
||||
"default": False,
|
||||
"help": "best results of models",
|
||||
},
|
||||
],
|
||||
"color": [
|
||||
("-c", "--color"),
|
||||
{
|
||||
"type": bool,
|
||||
"required": False,
|
||||
"action": "store_true",
|
||||
"default": False,
|
||||
"help": "use colors for the tree",
|
||||
},
|
||||
@@ -72,8 +75,9 @@ class Arguments:
|
||||
"compare": [
|
||||
("-c", "--compare"),
|
||||
{
|
||||
"type": bool,
|
||||
"action": "store_true",
|
||||
"required": False,
|
||||
"default": False,
|
||||
"help": "Compare accuracy with best results",
|
||||
},
|
||||
],
|
||||
@@ -81,6 +85,8 @@ class Arguments:
|
||||
("-d", "--dataset"),
|
||||
{
|
||||
"type": str,
|
||||
"envvar": "dataset", # for compatiblity with EnvDefault
|
||||
"action": EnvDefault,
|
||||
"required": False,
|
||||
"help": "dataset to work with",
|
||||
},
|
||||
@@ -88,8 +94,8 @@ class Arguments:
|
||||
"excel": [
|
||||
("-x", "--excel"),
|
||||
{
|
||||
"type": bool,
|
||||
"required": False,
|
||||
"action": "store_true",
|
||||
"default": False,
|
||||
"help": "Generate Excel File",
|
||||
},
|
||||
@@ -101,16 +107,17 @@ class Arguments:
|
||||
"grid": [
|
||||
("-g", "--grid"),
|
||||
{
|
||||
"type": str,
|
||||
"action": "store_true",
|
||||
"required": False,
|
||||
"default": False,
|
||||
"help": "grid results of model",
|
||||
},
|
||||
],
|
||||
"grid_paramfile": [
|
||||
("-g", "--grid_paramfile"),
|
||||
{
|
||||
"type": bool,
|
||||
"required": False,
|
||||
"action": "store_true",
|
||||
"default": False,
|
||||
"help": "Use best hyperparams file?",
|
||||
},
|
||||
@@ -118,8 +125,8 @@ class Arguments:
|
||||
"hidden": [
|
||||
("--hidden",),
|
||||
{
|
||||
"type": str,
|
||||
"required": False,
|
||||
"action": "store_true",
|
||||
"default": False,
|
||||
"help": "Show hidden results",
|
||||
},
|
||||
@@ -140,8 +147,8 @@ class Arguments:
|
||||
"lose": [
|
||||
("-l", "--lose"),
|
||||
{
|
||||
"type": bool,
|
||||
"default": False,
|
||||
"action": "store_true",
|
||||
"required": False,
|
||||
"help": "show lose results",
|
||||
},
|
||||
@@ -178,8 +185,9 @@ class Arguments:
|
||||
"nan": [
|
||||
("--nan",),
|
||||
{
|
||||
"type": bool,
|
||||
"action": "store_true",
|
||||
"required": False,
|
||||
"default": False,
|
||||
"help": "Move nan results to hidden folder",
|
||||
},
|
||||
],
|
||||
@@ -205,7 +213,7 @@ class Arguments:
|
||||
"paramfile": [
|
||||
("-f", "--paramfile"),
|
||||
{
|
||||
"type": bool,
|
||||
"action": "store_true",
|
||||
"required": False,
|
||||
"default": False,
|
||||
"help": "Use best hyperparams file?",
|
||||
@@ -224,7 +232,7 @@ class Arguments:
|
||||
"quiet": [
|
||||
("-q", "--quiet"),
|
||||
{
|
||||
"type": bool,
|
||||
"action": "store_true",
|
||||
"required": False,
|
||||
"default": False,
|
||||
},
|
||||
@@ -232,7 +240,7 @@ class Arguments:
|
||||
"report": [
|
||||
("-r", "--report"),
|
||||
{
|
||||
"type": bool,
|
||||
"action": "store_true",
|
||||
"default": False,
|
||||
"required": False,
|
||||
"help": "Report results",
|
||||
@@ -250,23 +258,27 @@ class Arguments:
|
||||
],
|
||||
"sql": [
|
||||
("-q", "--sql"),
|
||||
{"type": bool, "required": False, "help": "Generate SQL File"},
|
||||
{
|
||||
"required": False,
|
||||
"action": "store_true",
|
||||
"default": False,
|
||||
"help": "Generate SQL File",
|
||||
},
|
||||
],
|
||||
"stratified": [
|
||||
("-t", "--stratified"),
|
||||
{
|
||||
"action": EnvDefault,
|
||||
"envvar": "stratified",
|
||||
"type": str,
|
||||
"required": True,
|
||||
"required": False,
|
||||
"help": "Stratified",
|
||||
},
|
||||
],
|
||||
"tex_output": [
|
||||
("-t", "--tex-output"),
|
||||
{
|
||||
"type": bool,
|
||||
"required": False,
|
||||
"action": "store_true",
|
||||
"default": False,
|
||||
"help": "Generate Tex file with the table",
|
||||
},
|
||||
@@ -278,8 +290,8 @@ class Arguments:
|
||||
"win": [
|
||||
("-w", "--win"),
|
||||
{
|
||||
"type": bool,
|
||||
"default": False,
|
||||
"action": "store_true",
|
||||
"required": False,
|
||||
"help": "show win results",
|
||||
},
|
||||
@@ -287,12 +299,20 @@ class Arguments:
|
||||
}
|
||||
|
||||
def xset(self, *arg_name, **kwargs):
|
||||
names, default = self.parameters[arg_name[0]]
|
||||
names, parameters = self.parameters[arg_name[0]]
|
||||
if "overrides" in kwargs:
|
||||
self._overrides[names[0]] = (kwargs["overrides"], kwargs["const"])
|
||||
del kwargs["overrides"]
|
||||
self.ap.add_argument(
|
||||
*names,
|
||||
**{**default, **kwargs},
|
||||
**{**parameters, **kwargs},
|
||||
)
|
||||
return self
|
||||
|
||||
def parse(self, args=None):
|
||||
for key, (dest_key, value) in self._overrides.items():
|
||||
if args is None:
|
||||
args = sys.argv[1:]
|
||||
if key in args:
|
||||
args.extend((f"--{dest_key}", value))
|
||||
return self.ap.parse_args(args)
|
||||
|
@@ -14,7 +14,10 @@ def main(args_test=None):
|
||||
arguments.xset("stratified").xset("score").xset("model", mandatory=True)
|
||||
arguments.xset("n_folds").xset("platform").xset("quiet").xset("title")
|
||||
arguments.xset("hyperparameters").xset("paramfile").xset("report")
|
||||
arguments.xset("grid_paramfile").xset("dataset")
|
||||
arguments.xset("grid_paramfile")
|
||||
arguments.xset(
|
||||
"dataset", overrides="title", const="Test with only one dataset"
|
||||
)
|
||||
args = arguments.parse(args_test)
|
||||
report = args.report or args.dataset is not None
|
||||
if args.grid_paramfile:
|
||||
|
@@ -17,17 +17,17 @@ def main(args_test=None):
|
||||
arguments.xset("score", required=False)
|
||||
args = arguments.parse(args_test)
|
||||
if args.best:
|
||||
args.grid = None
|
||||
args.grid = False
|
||||
if args.grid:
|
||||
args.best = None
|
||||
if args.file is None and args.best is None and args.grid is None:
|
||||
args.best = False
|
||||
if args.file is None and not args.best and not args.grid:
|
||||
report = ReportDatasets(args.excel)
|
||||
report.report()
|
||||
if args.excel:
|
||||
is_test = args_test is not None
|
||||
Files.open(report.get_file_name(), is_test)
|
||||
else:
|
||||
if args.best is not None or args.grid is not None:
|
||||
if args.best or args.grid:
|
||||
report = ReportBest(args.score, args.model, args.best, args.grid)
|
||||
report.report()
|
||||
else:
|
||||
|
@@ -98,3 +98,27 @@ class ArgumentsTest(TestBase):
|
||||
finally:
|
||||
os.chdir(path)
|
||||
self.assertEqual(stderr.getvalue(), f"{NO_ENV}\n")
|
||||
|
||||
@patch("sys.stderr", new_callable=StringIO)
|
||||
def test_overrides(self, stderr):
|
||||
arguments = self.build_args()
|
||||
arguments.xset("title")
|
||||
arguments.xset("dataset", overrides="title", const="sample text")
|
||||
test_args = ["-n", "3", "-m", "SVC", "-k", "1", "-d", "dataset"]
|
||||
args = arguments.parse(test_args)
|
||||
self.assertEqual(stderr.getvalue(), "")
|
||||
self.assertEqual(args.title, "sample text")
|
||||
|
||||
@patch("sys.stderr", new_callable=StringIO)
|
||||
def test_overrides_no_args(self, stderr):
|
||||
arguments = self.build_args()
|
||||
arguments.xset("title")
|
||||
arguments.xset("dataset", overrides="title", const="sample text")
|
||||
test_args = None
|
||||
with self.assertRaises(SystemExit):
|
||||
arguments.parse(test_args)
|
||||
self.assertRegexpMatches(
|
||||
stderr.getvalue(),
|
||||
r"error: the following arguments are required: -m/--model, "
|
||||
"-k/--key, --title",
|
||||
)
|
||||
|
@@ -25,7 +25,7 @@ class BeBenchmarkTest(TestBase):
|
||||
|
||||
def test_be_benchmark_complete(self):
|
||||
stdout, stderr = self.execute_script(
|
||||
"be_benchmark", ["-s", self.score, "-q", "1", "-t", "1", "-x", "1"]
|
||||
"be_benchmark", ["-s", self.score, "-q", "-t", "-x"]
|
||||
)
|
||||
self.assertEqual(stderr.getvalue(), "")
|
||||
# Check output
|
||||
@@ -60,7 +60,7 @@ class BeBenchmarkTest(TestBase):
|
||||
|
||||
def test_be_benchmark_single(self):
|
||||
stdout, stderr = self.execute_script(
|
||||
"be_benchmark", ["-s", self.score, "-q", "1"]
|
||||
"be_benchmark", ["-s", self.score, "-q"]
|
||||
)
|
||||
self.assertEqual(stderr.getvalue(), "")
|
||||
# Check output
|
||||
|
@@ -67,7 +67,7 @@ class BeBestTest(TestBase):
|
||||
|
||||
def test_be_build_best_report(self):
|
||||
stdout, _ = self.execute_script(
|
||||
"be_build_best", ["-s", "accuracy", "-m", "ODTE", "-r", "1"]
|
||||
"be_build_best", ["-s", "accuracy", "-m", "ODTE", "-r"]
|
||||
)
|
||||
expected_data = {
|
||||
"balance-scale": [
|
||||
|
@@ -69,7 +69,7 @@ class BeGridTest(TestBase):
|
||||
def test_be_grid_no_input(self):
|
||||
stdout, stderr = self.execute_script(
|
||||
"be_grid",
|
||||
["-m", "ODTE", "-s", "f1-weighted", "-q", "1"],
|
||||
["-m", "ODTE", "-s", "f1-weighted", "-q"],
|
||||
)
|
||||
self.assertEqual(stderr.getvalue(), "")
|
||||
grid_file = os.path.join(
|
||||
|
@@ -29,9 +29,7 @@ class BeListTest(TestBase):
|
||||
|
||||
@patch("benchmark.Results.get_input", side_effect=iter(["q"]))
|
||||
def test_be_list_report_excel_none(self, input_data):
|
||||
stdout, stderr = self.execute_script(
|
||||
"be_list", ["-m", "STree", "-x", "1"]
|
||||
)
|
||||
stdout, stderr = self.execute_script("be_list", ["-m", "STree", "-x"])
|
||||
self.assertEqual(stderr.getvalue(), "")
|
||||
self.check_output_file(stdout, "be_list_model")
|
||||
|
||||
@@ -43,9 +41,7 @@ class BeListTest(TestBase):
|
||||
|
||||
@patch("benchmark.Results.get_input", side_effect=iter(["2", "q"]))
|
||||
def test_be_list_report_excel(self, input_data):
|
||||
stdout, stderr = self.execute_script(
|
||||
"be_list", ["-m", "STree", "-x", "1"]
|
||||
)
|
||||
stdout, stderr = self.execute_script("be_list", ["-m", "STree", "-x"])
|
||||
self.assertEqual(stderr.getvalue(), "")
|
||||
self.check_output_file(stdout, "be_list_report_excel")
|
||||
book = load_workbook(Files.be_list_excel)
|
||||
@@ -54,9 +50,7 @@ class BeListTest(TestBase):
|
||||
|
||||
@patch("benchmark.Results.get_input", side_effect=iter(["2", "1", "q"]))
|
||||
def test_be_list_report_excel_twice(self, input_data):
|
||||
stdout, stderr = self.execute_script(
|
||||
"be_list", ["-m", "STree", "-x", "1"]
|
||||
)
|
||||
stdout, stderr = self.execute_script("be_list", ["-m", "STree", "-x"])
|
||||
self.assertEqual(stderr.getvalue(), "")
|
||||
self.check_output_file(stdout, "be_list_report_excel_2")
|
||||
book = load_workbook(Files.be_list_excel)
|
||||
@@ -87,7 +81,7 @@ class BeListTest(TestBase):
|
||||
swap_files(Folders.hidden_results, Folders.results, file_name)
|
||||
try:
|
||||
# list and move nan result to hidden
|
||||
stdout, stderr = self.execute_script("be_list", ["--nan", "1"])
|
||||
stdout, stderr = self.execute_script("be_list", ["--nan"])
|
||||
self.assertEqual(stderr.getvalue(), "")
|
||||
self.check_output_file(stdout, "be_list_nan")
|
||||
except Exception:
|
||||
@@ -97,7 +91,7 @@ class BeListTest(TestBase):
|
||||
|
||||
@patch("benchmark.Results.get_input", return_value="q")
|
||||
def test_be_list_nan_no_nan(self, input_data):
|
||||
stdout, stderr = self.execute_script("be_list", ["--nan", "1"])
|
||||
stdout, stderr = self.execute_script("be_list", ["--nan"])
|
||||
self.assertEqual(stderr.getvalue(), "")
|
||||
self.check_output_file(stdout, "be_list_no_nan")
|
||||
|
||||
|
@@ -30,7 +30,7 @@ class BeMainTest(TestBase):
|
||||
def test_be_main_complete(self):
|
||||
stdout, _ = self.execute_script(
|
||||
"be_main",
|
||||
["-s", self.score, "-m", "STree", "--title", "test", "-r", "1"],
|
||||
["-s", self.score, "-m", "STree", "--title", "test", "-r"],
|
||||
)
|
||||
# keep the report name to delete it after
|
||||
report_name = stdout.getvalue().splitlines()[-1].split("in ")[1]
|
||||
@@ -67,9 +67,7 @@ class BeMainTest(TestBase):
|
||||
"--title",
|
||||
"test",
|
||||
"-f",
|
||||
"1",
|
||||
"-r",
|
||||
"1",
|
||||
],
|
||||
)
|
||||
# keep the report name to delete it after
|
||||
@@ -91,9 +89,7 @@ class BeMainTest(TestBase):
|
||||
"--title",
|
||||
"test",
|
||||
"-f",
|
||||
"1",
|
||||
"-r",
|
||||
"1",
|
||||
],
|
||||
)
|
||||
self.assertEqual(stderr.getvalue(), "")
|
||||
@@ -117,9 +113,7 @@ class BeMainTest(TestBase):
|
||||
"--title",
|
||||
"test",
|
||||
"-g",
|
||||
"1",
|
||||
"-r",
|
||||
"1",
|
||||
],
|
||||
)
|
||||
self.assertEqual(stderr.getvalue(), "")
|
||||
@@ -142,9 +136,7 @@ class BeMainTest(TestBase):
|
||||
"--title",
|
||||
"test",
|
||||
"-g",
|
||||
"1",
|
||||
"-r",
|
||||
"1",
|
||||
],
|
||||
)
|
||||
# keep the report name to delete it after
|
||||
|
@@ -18,7 +18,7 @@ class BePrintStrees(TestBase):
|
||||
for name in self.datasets:
|
||||
stdout, _ = self.execute_script(
|
||||
"be_print_strees",
|
||||
["-d", name, "-q", "1"],
|
||||
["-d", name, "-q"],
|
||||
)
|
||||
file_name = os.path.join(Folders.img, f"stree_{name}.png")
|
||||
self.files.append(file_name)
|
||||
@@ -33,7 +33,7 @@ class BePrintStrees(TestBase):
|
||||
for name in self.datasets:
|
||||
stdout, _ = self.execute_script(
|
||||
"be_print_strees",
|
||||
["-d", name, "-q", "1", "-c", "1"],
|
||||
["-d", name, "-q", "-c"],
|
||||
)
|
||||
file_name = os.path.join(Folders.img, f"stree_{name}.png")
|
||||
self.files.append(file_name)
|
||||
|
@@ -35,7 +35,7 @@ class BeReportTest(TestBase):
|
||||
def test_be_report_compare(self):
|
||||
file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json"
|
||||
stdout, stderr = self.execute_script(
|
||||
"be_report", ["-f", file_name, "-c", "1"]
|
||||
"be_report", ["-f", file_name, "-c"]
|
||||
)
|
||||
self.assertEqual(stderr.getvalue(), "")
|
||||
self.check_output_file(stdout, "report_compared")
|
||||
@@ -54,7 +54,7 @@ class BeReportTest(TestBase):
|
||||
self.assertEqual(line, output_text[index])
|
||||
|
||||
def test_be_report_datasets_excel(self):
|
||||
stdout, stderr = self.execute_script("be_report", ["-x", "1"])
|
||||
stdout, stderr = self.execute_script("be_report", ["-x"])
|
||||
self.assertEqual(stderr.getvalue(), "")
|
||||
file_name = f"report_datasets{self.ext}"
|
||||
with open(os.path.join(self.test_files, file_name)) as f:
|
||||
@@ -77,14 +77,14 @@ class BeReportTest(TestBase):
|
||||
|
||||
def test_be_report_best(self):
|
||||
stdout, stderr = self.execute_script(
|
||||
"be_report", ["-s", "accuracy", "-m", "STree", "-b", "1"]
|
||||
"be_report", ["-s", "accuracy", "-m", "STree", "-b"]
|
||||
)
|
||||
self.assertEqual(stderr.getvalue(), "")
|
||||
self.check_output_file(stdout, "report_best")
|
||||
|
||||
def test_be_report_grid(self):
|
||||
stdout, stderr = self.execute_script(
|
||||
"be_report", ["-s", "accuracy", "-m", "STree", "-g", "1"]
|
||||
"be_report", ["-s", "accuracy", "-m", "STree", "-g"]
|
||||
)
|
||||
self.assertEqual(stderr.getvalue(), "")
|
||||
file_name = "report_grid.test"
|
||||
@@ -101,7 +101,7 @@ class BeReportTest(TestBase):
|
||||
def test_be_report_best_both(self):
|
||||
stdout, stderr = self.execute_script(
|
||||
"be_report",
|
||||
["-s", "accuracy", "-m", "STree", "-b", "1", "-g", "1"],
|
||||
["-s", "accuracy", "-m", "STree", "-b", "-g"],
|
||||
)
|
||||
self.assertEqual(stderr.getvalue(), "")
|
||||
self.check_output_file(stdout, "report_best")
|
||||
@@ -110,7 +110,7 @@ class BeReportTest(TestBase):
|
||||
file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json"
|
||||
stdout, stderr = self.execute_script(
|
||||
"be_report",
|
||||
["-f", file_name, "-x", "1", "-c", "1"],
|
||||
["-f", file_name, "-x", "-c"],
|
||||
)
|
||||
file_name = os.path.join(
|
||||
Folders.results, file_name.replace(".json", ".xlsx")
|
||||
@@ -125,7 +125,7 @@ class BeReportTest(TestBase):
|
||||
file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json"
|
||||
stdout, stderr = self.execute_script(
|
||||
"be_report",
|
||||
["-f", file_name, "-x", "1"],
|
||||
["-f", file_name, "-x"],
|
||||
)
|
||||
file_name = os.path.join(
|
||||
Folders.results, file_name.replace(".json", ".xlsx")
|
||||
@@ -140,7 +140,7 @@ class BeReportTest(TestBase):
|
||||
file_name = "results_accuracy_ODTE_Galgo_2022-04-20_10:52:20_0.json"
|
||||
stdout, stderr = self.execute_script(
|
||||
"be_report",
|
||||
["-f", file_name, "-q", "1"],
|
||||
["-f", file_name, "-q"],
|
||||
)
|
||||
file_name = os.path.join(
|
||||
Folders.results, file_name.replace(".json", ".sql")
|
||||
|
@@ -1,6 +1,6 @@
|
||||
[94m*************************************************************************************************************************
|
||||
[94m* STree ver. 1.2.4 Python ver. 3.11x with 5 Folds cross validation and 10 random seeds. 2022-05-08 19:38:28 *
|
||||
[94m* test *
|
||||
[94m* Test with only one dataset *
|
||||
[94m* Random seeds: [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] Stratified: False *
|
||||
[94m* Execution took 0.06 seconds, 0.00 hours, on iMac27 *
|
||||
[94m* Score is accuracy *
|
||||
|
Reference in New Issue
Block a user