mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-16 16:05:54 +00:00
Fix be_print_strees issues
This commit is contained in:
@@ -16,13 +16,13 @@ def load_hyperparams(score_name, model_name):
|
|||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
# def hyperparam_filter(hyperparams):
|
def hyperparam_filter(hyperparams):
|
||||||
# res = {}
|
res = {}
|
||||||
# for key, value in hyperparams.items():
|
for key, value in hyperparams.items():
|
||||||
# if key.startswith("base_estimator"):
|
if key.startswith("base_estimator"):
|
||||||
# newkey = key.split("__")[1]
|
newkey = key.split("__")[1]
|
||||||
# res[newkey] = value
|
res[newkey] = value
|
||||||
# return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def build_title(dataset, accuracy, n_samples, n_features, n_classes, nodes):
|
def build_title(dataset, accuracy, n_samples, n_features, n_classes, nodes):
|
||||||
@@ -76,14 +76,16 @@ def main(args_test=None):
|
|||||||
arguments = Arguments()
|
arguments = Arguments()
|
||||||
arguments.xset("color").xset("dataset", default="all").xset("quiet")
|
arguments.xset("color").xset("dataset", default="all").xset("quiet")
|
||||||
args = arguments.parse(args_test)
|
args = arguments.parse(args_test)
|
||||||
hyperparameters = load_hyperparams("accuracy", "STree")
|
hyperparameters = load_hyperparams("accuracy", "ODTE")
|
||||||
random_state = 57
|
random_state = 57
|
||||||
dt = Datasets()
|
dt = Datasets()
|
||||||
for dataset in dt:
|
for dataset in dt:
|
||||||
if dataset == args.dataset or args.dataset == "all":
|
if dataset == args.dataset or args.dataset == "all":
|
||||||
X, y = dt.load(dataset)
|
X, y = dt.load(dataset)
|
||||||
clf = Stree(random_state=random_state)
|
clf = Stree(random_state=random_state)
|
||||||
hyperparams_dataset = hyperparameters[dataset][1]
|
hyperparams_dataset = hyperparam_filter(
|
||||||
|
hyperparameters[dataset][1]
|
||||||
|
)
|
||||||
clf.set_params(**hyperparams_dataset)
|
clf.set_params(**hyperparams_dataset)
|
||||||
clf.fit(X, y)
|
clf.fit(X, y)
|
||||||
print_stree(clf, dataset, X, y, args.color, args.quiet)
|
print_stree(clf, dataset, X, y, args.color, args.quiet)
|
||||||
|
@@ -55,84 +55,3 @@ class BePrintStrees(TestBase):
|
|||||||
stdout.getvalue(), f"File {file_name} generated\n"
|
stdout.getvalue(), f"File {file_name} generated\n"
|
||||||
)
|
)
|
||||||
self.assertEqual(computed_hash, self.expected[name]["color"])
|
self.assertEqual(computed_hash, self.expected[name]["color"])
|
||||||
|
|
||||||
# def test_be_benchmark_complete(self):
|
|
||||||
# stdout, _ = self.execute_script(
|
|
||||||
# "be_main",
|
|
||||||
# ["-s", self.score, "-m", "STree", "--title", "test", "-r", "1"],
|
|
||||||
# )
|
|
||||||
# # keep the report name to delete it after
|
|
||||||
# report_name = stdout.getvalue().splitlines()[-1].split("in ")[1]
|
|
||||||
# self.files.append(report_name)
|
|
||||||
# self.check_output_lines(
|
|
||||||
# stdout, "be_main_complete", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
|
|
||||||
# )
|
|
||||||
|
|
||||||
# def test_be_benchmark_no_report(self):
|
|
||||||
# stdout, _ = self.execute_script(
|
|
||||||
# "be_main",
|
|
||||||
# ["-s", self.score, "-m", "STree", "--title", "test"],
|
|
||||||
# )
|
|
||||||
# # keep the report name to delete it after
|
|
||||||
# report_name = stdout.getvalue().splitlines()[-1].split("in ")[1]
|
|
||||||
# self.files.append(report_name)
|
|
||||||
# report = Report(file_name=report_name)
|
|
||||||
# with patch(self.output, new=StringIO()) as stdout:
|
|
||||||
# report.report()
|
|
||||||
# self.check_output_lines(
|
|
||||||
# stdout,
|
|
||||||
# "be_main_complete",
|
|
||||||
# [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14],
|
|
||||||
# )
|
|
||||||
|
|
||||||
# def test_be_benchmark_best_params(self):
|
|
||||||
# stdout, _ = self.execute_script(
|
|
||||||
# "be_main",
|
|
||||||
# [
|
|
||||||
# "-s",
|
|
||||||
# self.score,
|
|
||||||
# "-m",
|
|
||||||
# "STree",
|
|
||||||
# "--title",
|
|
||||||
# "test",
|
|
||||||
# "-f",
|
|
||||||
# "1",
|
|
||||||
# "-r",
|
|
||||||
# "1",
|
|
||||||
# ],
|
|
||||||
# )
|
|
||||||
# # keep the report name to delete it after
|
|
||||||
# report_name = stdout.getvalue().splitlines()[-1].split("in ")[1]
|
|
||||||
# self.files.append(report_name)
|
|
||||||
# self.check_output_lines(
|
|
||||||
# stdout, "be_main_best", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
|
|
||||||
# )
|
|
||||||
|
|
||||||
# def test_be_benchmark_grid_params(self):
|
|
||||||
# stdout, _ = self.execute_script(
|
|
||||||
# "be_main",
|
|
||||||
# [
|
|
||||||
# "-s",
|
|
||||||
# self.score,
|
|
||||||
# "-m",
|
|
||||||
# "STree",
|
|
||||||
# "--title",
|
|
||||||
# "test",
|
|
||||||
# "-g",
|
|
||||||
# "1",
|
|
||||||
# "-r",
|
|
||||||
# "1",
|
|
||||||
# ],
|
|
||||||
# )
|
|
||||||
# # keep the report name to delete it after
|
|
||||||
# report_name = stdout.getvalue().splitlines()[-1].split("in ")[1]
|
|
||||||
# self.files.append(report_name)
|
|
||||||
# self.check_output_lines(
|
|
||||||
# stdout, "be_main_grid", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
|
|
||||||
# )
|
|
||||||
|
|
||||||
# def test_be_benchmark_no_data(self):
|
|
||||||
# stdout, _ = self.execute_script(
|
|
||||||
# "be_main", ["-m", "STree", "-d", "unknown", "--title", "test"]
|
|
||||||
# )
|
|
||||||
# self.assertEqual(stdout.getvalue(), "Unknown dataset: unknown\n")
|
|
||||||
|
@@ -6,3 +6,4 @@ xlsxwriter
|
|||||||
openpyxl
|
openpyxl
|
||||||
tqdm
|
tqdm
|
||||||
xgboost
|
xgboost
|
||||||
|
graphviz
|
||||||
|
Reference in New Issue
Block a user