mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-16 07:55:54 +00:00
Complete be_print_strees
This commit is contained in:
@@ -8,12 +8,6 @@ from benchmark.Utils import Files, Folders
|
||||
from benchmark.Arguments import Arguments
|
||||
|
||||
|
||||
def compute_stree(X, y, random_state):
|
||||
clf = Stree(random_state=random_state)
|
||||
clf.fit(X, y)
|
||||
return clf
|
||||
|
||||
|
||||
def load_hyperparams(score_name, model_name):
|
||||
grid_file = os.path.join(
|
||||
Folders.results, Files.grid_output(score_name, model_name)
|
||||
@@ -22,13 +16,13 @@ def load_hyperparams(score_name, model_name):
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def hyperparam_filter(hyperparams):
|
||||
res = {}
|
||||
for key, value in hyperparams.items():
|
||||
if key.startswith("base_estimator"):
|
||||
newkey = key.split("__")[1]
|
||||
res[newkey] = value
|
||||
return res
|
||||
# def hyperparam_filter(hyperparams):
|
||||
# res = {}
|
||||
# for key, value in hyperparams.items():
|
||||
# if key.startswith("base_estimator"):
|
||||
# newkey = key.split("__")[1]
|
||||
# res[newkey] = value
|
||||
# return res
|
||||
|
||||
|
||||
def build_title(dataset, accuracy, n_samples, n_features, n_classes, nodes):
|
||||
@@ -89,9 +83,7 @@ def main(args_test=None):
|
||||
if dataset == args.dataset or args.dataset == "all":
|
||||
X, y = dt.load(dataset)
|
||||
clf = Stree(random_state=random_state)
|
||||
hyperparams_dataset = hyperparam_filter(
|
||||
hyperparameters[dataset][1]
|
||||
)
|
||||
hyperparams_dataset = hyperparameters[dataset][1]
|
||||
clf.set_params(**hyperparams_dataset)
|
||||
clf.fit(X, y)
|
||||
print_stree(clf, dataset, X, y, args.color, args.quiet)
|
||||
|
@@ -1,34 +1,60 @@
|
||||
import shutil
|
||||
from io import StringIO
|
||||
from unittest.mock import patch
|
||||
from ...Results import Report
|
||||
from ...Utils import Files
|
||||
import os
|
||||
import hashlib
|
||||
from ...Utils import Folders
|
||||
from ..TestBase import TestBase
|
||||
|
||||
|
||||
class BePrintStrees(TestBase):
|
||||
def setUp(self):
|
||||
self.prepare_scripts_env()
|
||||
source = Files.grid_output("accuracy", "STree")
|
||||
target = Files.grid_output("accuracy", "STree")
|
||||
shutil.copy(source, target)
|
||||
self.score = "accuracy"
|
||||
self.files = [target]
|
||||
self.files = []
|
||||
self.datasets = ["balloons", "balance-scale"]
|
||||
self.expected = {
|
||||
"balloons": {
|
||||
"color": "b2342cc27a4ab495970616346bedf73b",
|
||||
"gray": "a9bc4d2041f2869a93164a548f6ad986",
|
||||
},
|
||||
"balance-scale": {
|
||||
"color": "2e85d66de1ae838d01a3f327397a50c8",
|
||||
"gray": "30f325134d4b5153c9e6ecbcae7b6d1f",
|
||||
},
|
||||
}
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.remove_files(self.files, ".")
|
||||
return super().tearDown()
|
||||
|
||||
# def test_be_benchmark_dataset(self):
|
||||
# stdout, _ = self.execute_script(
|
||||
# "be_main",
|
||||
# ["-m", "STree", "-d", "balloons", "--title", "test"],
|
||||
# )
|
||||
# self.check_output_lines(
|
||||
# stdout=stdout,
|
||||
# file_name="be_main_dataset",
|
||||
# lines_to_compare=[0, 2, 3, 5, 6, 7, 8, 9, 11, 12, 13],
|
||||
# )
|
||||
def hash_file(self, name):
|
||||
file_name = os.path.join(Folders.img, f"{name}.png")
|
||||
self.files.append(file_name)
|
||||
self.assertTrue(os.path.exists(file_name))
|
||||
with open(file_name, "rb") as f:
|
||||
return hashlib.md5(f.read()).hexdigest(), file_name
|
||||
|
||||
def test_be_print_strees_dataset_bn(self):
|
||||
for name in self.datasets:
|
||||
stdout, _ = self.execute_script(
|
||||
"be_print_strees",
|
||||
["-d", name, "-q", "1"],
|
||||
)
|
||||
computed_hash, file_name = self.hash_file(f"stree_{name}")
|
||||
self.assertEqual(
|
||||
stdout.getvalue(), f"File {file_name} generated\n"
|
||||
)
|
||||
self.assertEqual(computed_hash, self.expected[name]["gray"])
|
||||
|
||||
def test_be_print_strees_dataset_color(self):
|
||||
for name in self.datasets:
|
||||
stdout, _ = self.execute_script(
|
||||
"be_print_strees",
|
||||
["-d", name, "-q", "1", "-c", "1"],
|
||||
)
|
||||
computed_hash, file_name = self.hash_file(f"stree_{name}")
|
||||
self.assertEqual(
|
||||
stdout.getvalue(), f"File {file_name} generated\n"
|
||||
)
|
||||
self.assertEqual(computed_hash, self.expected[name]["color"])
|
||||
|
||||
# def test_be_benchmark_complete(self):
|
||||
# stdout, _ = self.execute_script(
|
||||
|
Reference in New Issue
Block a user