diff --git a/benchmark/scripts/be_print_strees.py b/benchmark/scripts/be_print_strees.py index 5122cbe..0a55a19 100755 --- a/benchmark/scripts/be_print_strees.py +++ b/benchmark/scripts/be_print_strees.py @@ -16,13 +16,13 @@ def load_hyperparams(score_name, model_name): return json.load(f) -# def hyperparam_filter(hyperparams): -# res = {} -# for key, value in hyperparams.items(): -# if key.startswith("base_estimator"): -# newkey = key.split("__")[1] -# res[newkey] = value -# return res +def hyperparam_filter(hyperparams): + res = {} + for key, value in hyperparams.items(): + if key.startswith("base_estimator"): + newkey = key.split("__")[1] + res[newkey] = value + return res def build_title(dataset, accuracy, n_samples, n_features, n_classes, nodes): @@ -76,14 +76,16 @@ def main(args_test=None): arguments = Arguments() arguments.xset("color").xset("dataset", default="all").xset("quiet") args = arguments.parse(args_test) - hyperparameters = load_hyperparams("accuracy", "STree") + hyperparameters = load_hyperparams("accuracy", "ODTE") random_state = 57 dt = Datasets() for dataset in dt: if dataset == args.dataset or args.dataset == "all": X, y = dt.load(dataset) clf = Stree(random_state=random_state) - hyperparams_dataset = hyperparameters[dataset][1] + hyperparams_dataset = hyperparam_filter( + hyperparameters[dataset][1] + ) clf.set_params(**hyperparams_dataset) clf.fit(X, y) print_stree(clf, dataset, X, y, args.color, args.quiet) diff --git a/benchmark/tests/scripts/Be_Print_Strees_test.py b/benchmark/tests/scripts/Be_Print_Strees_test.py index bc9ead0..67a4da2 100644 --- a/benchmark/tests/scripts/Be_Print_Strees_test.py +++ b/benchmark/tests/scripts/Be_Print_Strees_test.py @@ -55,84 +55,3 @@ class BePrintStrees(TestBase): stdout.getvalue(), f"File {file_name} generated\n" ) self.assertEqual(computed_hash, self.expected[name]["color"]) - - # def test_be_benchmark_complete(self): - # stdout, _ = self.execute_script( - # "be_main", - # ["-s", self.score, "-m", "STree", "--title", "test", "-r", "1"], - # ) - # # keep the report name to delete it after - # report_name = stdout.getvalue().splitlines()[-1].split("in ")[1] - # self.files.append(report_name) - # self.check_output_lines( - # stdout, "be_main_complete", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14] - # ) - - # def test_be_benchmark_no_report(self): - # stdout, _ = self.execute_script( - # "be_main", - # ["-s", self.score, "-m", "STree", "--title", "test"], - # ) - # # keep the report name to delete it after - # report_name = stdout.getvalue().splitlines()[-1].split("in ")[1] - # self.files.append(report_name) - # report = Report(file_name=report_name) - # with patch(self.output, new=StringIO()) as stdout: - # report.report() - # self.check_output_lines( - # stdout, - # "be_main_complete", - # [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14], - # ) - - # def test_be_benchmark_best_params(self): - # stdout, _ = self.execute_script( - # "be_main", - # [ - # "-s", - # self.score, - # "-m", - # "STree", - # "--title", - # "test", - # "-f", - # "1", - # "-r", - # "1", - # ], - # ) - # # keep the report name to delete it after - # report_name = stdout.getvalue().splitlines()[-1].split("in ")[1] - # self.files.append(report_name) - # self.check_output_lines( - # stdout, "be_main_best", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14] - # ) - - # def test_be_benchmark_grid_params(self): - # stdout, _ = self.execute_script( - # "be_main", - # [ - # "-s", - # self.score, - # "-m", - # "STree", - # "--title", - # "test", - # "-g", - # "1", - # "-r", - # "1", - # ], - # ) - # # keep the report name to delete it after - # report_name = stdout.getvalue().splitlines()[-1].split("in ")[1] - # self.files.append(report_name) - # self.check_output_lines( - # stdout, "be_main_grid", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14] - # ) - - # def test_be_benchmark_no_data(self): - # stdout, _ = self.execute_script( - # "be_main", ["-m", "STree", "-d", "unknown", "--title", "test"] - # ) - # self.assertEqual(stdout.getvalue(), "Unknown dataset: unknown\n") diff --git a/requirements.txt b/requirements.txt index 3af1412..e5053c9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ xlsxwriter openpyxl tqdm xgboost +graphviz