mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-18 08:55:53 +00:00
Complete be_print_strees
This commit is contained in:
@@ -8,12 +8,6 @@ from benchmark.Utils import Files, Folders
|
||||
from benchmark.Arguments import Arguments
|
||||
|
||||
|
||||
def compute_stree(X, y, random_state):
|
||||
clf = Stree(random_state=random_state)
|
||||
clf.fit(X, y)
|
||||
return clf
|
||||
|
||||
|
||||
def load_hyperparams(score_name, model_name):
|
||||
grid_file = os.path.join(
|
||||
Folders.results, Files.grid_output(score_name, model_name)
|
||||
@@ -22,13 +16,13 @@ def load_hyperparams(score_name, model_name):
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def hyperparam_filter(hyperparams):
|
||||
res = {}
|
||||
for key, value in hyperparams.items():
|
||||
if key.startswith("base_estimator"):
|
||||
newkey = key.split("__")[1]
|
||||
res[newkey] = value
|
||||
return res
|
||||
# def hyperparam_filter(hyperparams):
|
||||
# res = {}
|
||||
# for key, value in hyperparams.items():
|
||||
# if key.startswith("base_estimator"):
|
||||
# newkey = key.split("__")[1]
|
||||
# res[newkey] = value
|
||||
# return res
|
||||
|
||||
|
||||
def build_title(dataset, accuracy, n_samples, n_features, n_classes, nodes):
|
||||
@@ -89,9 +83,7 @@ def main(args_test=None):
|
||||
if dataset == args.dataset or args.dataset == "all":
|
||||
X, y = dt.load(dataset)
|
||||
clf = Stree(random_state=random_state)
|
||||
hyperparams_dataset = hyperparam_filter(
|
||||
hyperparameters[dataset][1]
|
||||
)
|
||||
hyperparams_dataset = hyperparameters[dataset][1]
|
||||
clf.set_params(**hyperparams_dataset)
|
||||
clf.fit(X, y)
|
||||
print_stree(clf, dataset, X, y, args.color, args.quiet)
|
||||
|
Reference in New Issue
Block a user