Complete be_print_strees

This commit is contained in:
2022-05-09 01:34:25 +02:00
parent b0c94d4983
commit ca96d05124
2 changed files with 53 additions and 35 deletions

View File

@@ -8,12 +8,6 @@ from benchmark.Utils import Files, Folders
from benchmark.Arguments import Arguments
def compute_stree(X, y, random_state):
clf = Stree(random_state=random_state)
clf.fit(X, y)
return clf
def load_hyperparams(score_name, model_name):
grid_file = os.path.join(
Folders.results, Files.grid_output(score_name, model_name)
@@ -22,13 +16,13 @@ def load_hyperparams(score_name, model_name):
return json.load(f)
def hyperparam_filter(hyperparams):
res = {}
for key, value in hyperparams.items():
if key.startswith("base_estimator"):
newkey = key.split("__")[1]
res[newkey] = value
return res
# def hyperparam_filter(hyperparams):
# res = {}
# for key, value in hyperparams.items():
# if key.startswith("base_estimator"):
# newkey = key.split("__")[1]
# res[newkey] = value
# return res
def build_title(dataset, accuracy, n_samples, n_features, n_classes, nodes):
@@ -89,9 +83,7 @@ def main(args_test=None):
if dataset == args.dataset or args.dataset == "all":
X, y = dt.load(dataset)
clf = Stree(random_state=random_state)
hyperparams_dataset = hyperparam_filter(
hyperparameters[dataset][1]
)
hyperparams_dataset = hyperparameters[dataset][1]
clf.set_params(**hyperparams_dataset)
clf.fit(X, y)
print_stree(clf, dataset, X, y, args.color, args.quiet)