Complete be_print_strees

This commit is contained in:
2022-05-09 01:34:25 +02:00
parent b0c94d4983
commit ca96d05124
2 changed files with 53 additions and 35 deletions

View File

@@ -8,12 +8,6 @@ from benchmark.Utils import Files, Folders
from benchmark.Arguments import Arguments from benchmark.Arguments import Arguments
def compute_stree(X, y, random_state):
clf = Stree(random_state=random_state)
clf.fit(X, y)
return clf
def load_hyperparams(score_name, model_name): def load_hyperparams(score_name, model_name):
grid_file = os.path.join( grid_file = os.path.join(
Folders.results, Files.grid_output(score_name, model_name) Folders.results, Files.grid_output(score_name, model_name)
@@ -22,13 +16,13 @@ def load_hyperparams(score_name, model_name):
return json.load(f) return json.load(f)
def hyperparam_filter(hyperparams): # def hyperparam_filter(hyperparams):
res = {} # res = {}
for key, value in hyperparams.items(): # for key, value in hyperparams.items():
if key.startswith("base_estimator"): # if key.startswith("base_estimator"):
newkey = key.split("__")[1] # newkey = key.split("__")[1]
res[newkey] = value # res[newkey] = value
return res # return res
def build_title(dataset, accuracy, n_samples, n_features, n_classes, nodes): def build_title(dataset, accuracy, n_samples, n_features, n_classes, nodes):
@@ -89,9 +83,7 @@ def main(args_test=None):
if dataset == args.dataset or args.dataset == "all": if dataset == args.dataset or args.dataset == "all":
X, y = dt.load(dataset) X, y = dt.load(dataset)
clf = Stree(random_state=random_state) clf = Stree(random_state=random_state)
hyperparams_dataset = hyperparam_filter( hyperparams_dataset = hyperparameters[dataset][1]
hyperparameters[dataset][1]
)
clf.set_params(**hyperparams_dataset) clf.set_params(**hyperparams_dataset)
clf.fit(X, y) clf.fit(X, y)
print_stree(clf, dataset, X, y, args.color, args.quiet) print_stree(clf, dataset, X, y, args.color, args.quiet)

View File

@@ -1,34 +1,60 @@
import shutil import os
from io import StringIO import hashlib
from unittest.mock import patch from ...Utils import Folders
from ...Results import Report
from ...Utils import Files
from ..TestBase import TestBase from ..TestBase import TestBase
class BePrintStrees(TestBase): class BePrintStrees(TestBase):
def setUp(self): def setUp(self):
self.prepare_scripts_env() self.prepare_scripts_env()
source = Files.grid_output("accuracy", "STree")
target = Files.grid_output("accuracy", "STree")
shutil.copy(source, target)
self.score = "accuracy" self.score = "accuracy"
self.files = [target] self.files = []
self.datasets = ["balloons", "balance-scale"]
self.expected = {
"balloons": {
"color": "b2342cc27a4ab495970616346bedf73b",
"gray": "a9bc4d2041f2869a93164a548f6ad986",
},
"balance-scale": {
"color": "2e85d66de1ae838d01a3f327397a50c8",
"gray": "30f325134d4b5153c9e6ecbcae7b6d1f",
},
}
def tearDown(self) -> None: def tearDown(self) -> None:
self.remove_files(self.files, ".") self.remove_files(self.files, ".")
return super().tearDown() return super().tearDown()
# def test_be_benchmark_dataset(self): def hash_file(self, name):
# stdout, _ = self.execute_script( file_name = os.path.join(Folders.img, f"{name}.png")
# "be_main", self.files.append(file_name)
# ["-m", "STree", "-d", "balloons", "--title", "test"], self.assertTrue(os.path.exists(file_name))
# ) with open(file_name, "rb") as f:
# self.check_output_lines( return hashlib.md5(f.read()).hexdigest(), file_name
# stdout=stdout,
# file_name="be_main_dataset", def test_be_print_strees_dataset_bn(self):
# lines_to_compare=[0, 2, 3, 5, 6, 7, 8, 9, 11, 12, 13], for name in self.datasets:
# ) stdout, _ = self.execute_script(
"be_print_strees",
["-d", name, "-q", "1"],
)
computed_hash, file_name = self.hash_file(f"stree_{name}")
self.assertEqual(
stdout.getvalue(), f"File {file_name} generated\n"
)
self.assertEqual(computed_hash, self.expected[name]["gray"])
def test_be_print_strees_dataset_color(self):
for name in self.datasets:
stdout, _ = self.execute_script(
"be_print_strees",
["-d", name, "-q", "1", "-c", "1"],
)
computed_hash, file_name = self.hash_file(f"stree_{name}")
self.assertEqual(
stdout.getvalue(), f"File {file_name} generated\n"
)
self.assertEqual(computed_hash, self.expected[name]["color"])
# def test_be_benchmark_complete(self): # def test_be_benchmark_complete(self):
# stdout, _ = self.execute_script( # stdout, _ = self.execute_script(