Complete be_print_strees

This commit is contained in:
2022-05-09 01:34:25 +02:00
parent b0c94d4983
commit ca96d05124
2 changed files with 53 additions and 35 deletions

View File

@@ -8,12 +8,6 @@ from benchmark.Utils import Files, Folders
from benchmark.Arguments import Arguments
def compute_stree(X, y, random_state):
clf = Stree(random_state=random_state)
clf.fit(X, y)
return clf
def load_hyperparams(score_name, model_name):
grid_file = os.path.join(
Folders.results, Files.grid_output(score_name, model_name)
@@ -22,13 +16,13 @@ def load_hyperparams(score_name, model_name):
return json.load(f)
def hyperparam_filter(hyperparams):
res = {}
for key, value in hyperparams.items():
if key.startswith("base_estimator"):
newkey = key.split("__")[1]
res[newkey] = value
return res
# def hyperparam_filter(hyperparams):
# res = {}
# for key, value in hyperparams.items():
# if key.startswith("base_estimator"):
# newkey = key.split("__")[1]
# res[newkey] = value
# return res
def build_title(dataset, accuracy, n_samples, n_features, n_classes, nodes):
@@ -89,9 +83,7 @@ def main(args_test=None):
if dataset == args.dataset or args.dataset == "all":
X, y = dt.load(dataset)
clf = Stree(random_state=random_state)
hyperparams_dataset = hyperparam_filter(
hyperparameters[dataset][1]
)
hyperparams_dataset = hyperparameters[dataset][1]
clf.set_params(**hyperparams_dataset)
clf.fit(X, y)
print_stree(clf, dataset, X, y, args.color, args.quiet)

View File

@@ -1,34 +1,60 @@
import shutil
from io import StringIO
from unittest.mock import patch
from ...Results import Report
from ...Utils import Files
import os
import hashlib
from ...Utils import Folders
from ..TestBase import TestBase
class BePrintStrees(TestBase):
def setUp(self):
self.prepare_scripts_env()
source = Files.grid_output("accuracy", "STree")
target = Files.grid_output("accuracy", "STree")
shutil.copy(source, target)
self.score = "accuracy"
self.files = [target]
self.files = []
self.datasets = ["balloons", "balance-scale"]
self.expected = {
"balloons": {
"color": "b2342cc27a4ab495970616346bedf73b",
"gray": "a9bc4d2041f2869a93164a548f6ad986",
},
"balance-scale": {
"color": "2e85d66de1ae838d01a3f327397a50c8",
"gray": "30f325134d4b5153c9e6ecbcae7b6d1f",
},
}
def tearDown(self) -> None:
self.remove_files(self.files, ".")
return super().tearDown()
# def test_be_benchmark_dataset(self):
# stdout, _ = self.execute_script(
# "be_main",
# ["-m", "STree", "-d", "balloons", "--title", "test"],
# )
# self.check_output_lines(
# stdout=stdout,
# file_name="be_main_dataset",
# lines_to_compare=[0, 2, 3, 5, 6, 7, 8, 9, 11, 12, 13],
# )
def hash_file(self, name):
file_name = os.path.join(Folders.img, f"{name}.png")
self.files.append(file_name)
self.assertTrue(os.path.exists(file_name))
with open(file_name, "rb") as f:
return hashlib.md5(f.read()).hexdigest(), file_name
def test_be_print_strees_dataset_bn(self):
for name in self.datasets:
stdout, _ = self.execute_script(
"be_print_strees",
["-d", name, "-q", "1"],
)
computed_hash, file_name = self.hash_file(f"stree_{name}")
self.assertEqual(
stdout.getvalue(), f"File {file_name} generated\n"
)
self.assertEqual(computed_hash, self.expected[name]["gray"])
def test_be_print_strees_dataset_color(self):
for name in self.datasets:
stdout, _ = self.execute_script(
"be_print_strees",
["-d", name, "-q", "1", "-c", "1"],
)
computed_hash, file_name = self.hash_file(f"stree_{name}")
self.assertEqual(
stdout.getvalue(), f"File {file_name} generated\n"
)
self.assertEqual(computed_hash, self.expected[name]["color"])
# def test_be_benchmark_complete(self):
# stdout, _ = self.execute_script(