mirror of
https://github.com/Doctorado-ML/Stree_datasets.git
synced 2025-08-18 00:46:03 +00:00
add kernel hyperparameter subset in gridsearch
This commit is contained in:
@@ -55,11 +55,24 @@ def report_line(line):
|
|||||||
|
|
||||||
|
|
||||||
def report_footer(agg):
|
def report_footer(agg):
|
||||||
print(TextColor.GREEN + f"we have better results {agg['better']:2d} times")
|
print(
|
||||||
print(TextColor.RED + f"we have worse results {agg['worse']:2d} times")
|
TextColor.GREEN
|
||||||
|
+ f"we have better results {agg['better']['items']:2d} times"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
TextColor.RED
|
||||||
|
+ f"we have worse results {agg['worse']['items']:2d} times"
|
||||||
|
)
|
||||||
color = TextColor.LINE1
|
color = TextColor.LINE1
|
||||||
for item in models:
|
for item in models:
|
||||||
print(color + f"{item:10s} used {agg[item]:2d} times")
|
print(
|
||||||
|
color + f"{item:10s} used {agg[item]['items']:2d} times ", end=""
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
color + f"better {agg[item]['better']:2d} times ",
|
||||||
|
end="",
|
||||||
|
)
|
||||||
|
print(color + f"worse {agg[item]['worse']:2d} times ")
|
||||||
color = (
|
color = (
|
||||||
TextColor.LINE2 if color == TextColor.LINE1 else TextColor.LINE1
|
TextColor.LINE2 if color == TextColor.LINE1 else TextColor.LINE1
|
||||||
)
|
)
|
||||||
@@ -77,7 +90,10 @@ for item in [
|
|||||||
"better",
|
"better",
|
||||||
"worse",
|
"worse",
|
||||||
] + models:
|
] + models:
|
||||||
agg[item] = 0
|
agg[item] = {}
|
||||||
|
agg[item]["items"] = 0
|
||||||
|
agg[item]["better"] = 0
|
||||||
|
agg[item]["worse"] = 0
|
||||||
for dataset in dt:
|
for dataset in dt:
|
||||||
find_one = False
|
find_one = False
|
||||||
line = {"dataset": color + dataset[0]}
|
line = {"dataset": color + dataset[0]}
|
||||||
@@ -91,13 +107,15 @@ for dataset in dt:
|
|||||||
reference = record[10]
|
reference = record[10]
|
||||||
accuracy = record[5]
|
accuracy = record[5]
|
||||||
find_one = True
|
find_one = True
|
||||||
agg[model] += 1
|
agg[model]["items"] += 1
|
||||||
if accuracy > reference:
|
if accuracy > reference:
|
||||||
sign = "+"
|
sign = "+"
|
||||||
agg["better"] += 1
|
agg["better"]["items"] += 1
|
||||||
|
agg[model]["better"] += 1
|
||||||
else:
|
else:
|
||||||
sign = "-"
|
sign = "-"
|
||||||
agg["worse"] += 1
|
agg["worse"]["items"] += 1
|
||||||
|
agg[model]["worse"] += 1
|
||||||
item = f"{accuracy:9.7} {sign}"
|
item = f"{accuracy:9.7} {sign}"
|
||||||
line["reference"] = f"{reference:9.7}"
|
line["reference"] = f"{reference:9.7}"
|
||||||
line[model] = (
|
line[model] = (
|
||||||
|
@@ -47,6 +47,20 @@ def parse_arguments() -> Tuple[str, str, str, str, str, bool, bool, dict]:
|
|||||||
help="Experiment: {gridsearch, gridbest, crossval, report_grid, "
|
help="Experiment: {gridsearch, gridbest, crossval, report_grid, "
|
||||||
"report_cross}",
|
"report_cross}",
|
||||||
)
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"-k",
|
||||||
|
"--kernel",
|
||||||
|
type=str,
|
||||||
|
choices=[
|
||||||
|
"linear",
|
||||||
|
"poly",
|
||||||
|
"rbf",
|
||||||
|
"any",
|
||||||
|
],
|
||||||
|
required=False,
|
||||||
|
default="any",
|
||||||
|
help="Kernel: {linear, poly, rbf, any} only used in gridsearch",
|
||||||
|
)
|
||||||
ap.add_argument(
|
ap.add_argument(
|
||||||
"-d",
|
"-d",
|
||||||
"--dataset",
|
"--dataset",
|
||||||
@@ -88,6 +102,7 @@ def parse_arguments() -> Tuple[str, str, str, str, str, bool, bool, dict]:
|
|||||||
args.normalize,
|
args.normalize,
|
||||||
args.standardize,
|
args.standardize,
|
||||||
args.excludeparams,
|
args.excludeparams,
|
||||||
|
args.kernel,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -100,10 +115,15 @@ def parse_arguments() -> Tuple[str, str, str, str, str, bool, bool, dict]:
|
|||||||
normalize,
|
normalize,
|
||||||
standardize,
|
standardize,
|
||||||
exclude_params,
|
exclude_params,
|
||||||
|
kernel,
|
||||||
) = parse_arguments()
|
) = parse_arguments()
|
||||||
|
|
||||||
experiment = Experiment(
|
experiment = Experiment(
|
||||||
random_state=1, model=model, host=host, set_of_files=set_of_files
|
random_state=1,
|
||||||
|
model=model,
|
||||||
|
host=host,
|
||||||
|
set_of_files=set_of_files,
|
||||||
|
kernel=kernel,
|
||||||
)
|
)
|
||||||
if experiment_type[0:6] == "report":
|
if experiment_type[0:6] == "report":
|
||||||
bd = (
|
bd = (
|
||||||
|
@@ -12,7 +12,12 @@ from .Sets import Datasets
|
|||||||
|
|
||||||
class Experiment:
|
class Experiment:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, random_state: int, model: str, host: str, set_of_files: str
|
self,
|
||||||
|
random_state: int,
|
||||||
|
model: str,
|
||||||
|
host: str,
|
||||||
|
set_of_files: str,
|
||||||
|
kernel: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._random_state = random_state
|
self._random_state = random_state
|
||||||
self._model_name = model
|
self._model_name = model
|
||||||
@@ -26,6 +31,7 @@ class Experiment:
|
|||||||
# used in gridsearch with ensembles to take best hyperparams of
|
# used in gridsearch with ensembles to take best hyperparams of
|
||||||
# base class or gridsearch these hyperparams as well
|
# base class or gridsearch these hyperparams as well
|
||||||
self._base_params = "any"
|
self._base_params = "any"
|
||||||
|
self._kernel = kernel
|
||||||
|
|
||||||
def set_base_params(self, base_params: str) -> None:
|
def set_base_params(self, base_params: str) -> None:
|
||||||
self._base_params = base_params
|
self._base_params = base_params
|
||||||
@@ -73,6 +79,12 @@ class Experiment:
|
|||||||
"""
|
"""
|
||||||
hyperparams = Hyperparameters(host=self._host, model=self._model_name)
|
hyperparams = Hyperparameters(host=self._host, model=self._model_name)
|
||||||
model = self._clf.get_model()
|
model = self._clf.get_model()
|
||||||
|
if self._kernel != "any":
|
||||||
|
# set parameters grid to only one kernel
|
||||||
|
if isinstance(self._clf, Models.Ensemble):
|
||||||
|
self._clf._base_model.select_params(self._kernel)
|
||||||
|
else:
|
||||||
|
self._clf.select_params(self._kernel)
|
||||||
hyperparameters = self._clf.get_parameters()
|
hyperparameters = self._clf.get_parameters()
|
||||||
grid_type = "gridsearch"
|
grid_type = "gridsearch"
|
||||||
if (
|
if (
|
||||||
@@ -111,7 +123,8 @@ class Experiment:
|
|||||||
model,
|
model,
|
||||||
return_train_score=True,
|
return_train_score=True,
|
||||||
param_grid=hyperparameters,
|
param_grid=hyperparameters,
|
||||||
n_jobs=-1,
|
n_jobs=1,
|
||||||
|
verbose=1,
|
||||||
)
|
)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
grid_search.fit(X, y)
|
grid_search.fit(X, y)
|
||||||
|
@@ -49,15 +49,14 @@ class ModelStree(ModelBase):
|
|||||||
gamma = [1e-1, 1, 1e1]
|
gamma = [1e-1, 1, 1e1]
|
||||||
max_features = [None, "auto"]
|
max_features = [None, "auto"]
|
||||||
split_criteria = ["impurity", "max_samples"]
|
split_criteria = ["impurity", "max_samples"]
|
||||||
self._param_grid = [
|
self._linear = {
|
||||||
{
|
|
||||||
"random_state": [self._random_state],
|
"random_state": [self._random_state],
|
||||||
"C": C,
|
"C": C,
|
||||||
"max_iter": max_iter,
|
"max_iter": max_iter,
|
||||||
"split_criteria": split_criteria,
|
"split_criteria": split_criteria,
|
||||||
"max_features": max_features,
|
"max_features": max_features,
|
||||||
},
|
}
|
||||||
{
|
self._poly = {
|
||||||
"random_state": [self._random_state],
|
"random_state": [self._random_state],
|
||||||
"kernel": ["rbf"],
|
"kernel": ["rbf"],
|
||||||
"C": C,
|
"C": C,
|
||||||
@@ -65,8 +64,8 @@ class ModelStree(ModelBase):
|
|||||||
"max_iter": max_iter,
|
"max_iter": max_iter,
|
||||||
"split_criteria": split_criteria,
|
"split_criteria": split_criteria,
|
||||||
"max_features": max_features,
|
"max_features": max_features,
|
||||||
},
|
}
|
||||||
{
|
self._rbf = {
|
||||||
"random_state": [self._random_state],
|
"random_state": [self._random_state],
|
||||||
"kernel": ["poly"],
|
"kernel": ["poly"],
|
||||||
"degree": [3, 5],
|
"degree": [3, 5],
|
||||||
@@ -75,9 +74,21 @@ class ModelStree(ModelBase):
|
|||||||
"max_iter": max_iter,
|
"max_iter": max_iter,
|
||||||
"split_criteria": split_criteria,
|
"split_criteria": split_criteria,
|
||||||
"max_features": max_features,
|
"max_features": max_features,
|
||||||
},
|
}
|
||||||
|
self._param_grid = [
|
||||||
|
self._linear,
|
||||||
|
self._poly,
|
||||||
|
self._rbf,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def select_params(self, kernel: str) -> None:
    """Restrict the grid-search parameter grid to a single kernel.

    Replaces ``self._param_grid`` with the subset of the three grids built
    in ``__init__`` (``self._linear``, ``self._poly``, ``self._rbf``) whose
    declared ``"kernel"`` values include *kernel*.

    Args:
        kernel: ``"linear"``, ``"poly"`` or ``"rbf"`` to keep only that
            kernel's grid, or ``"any"`` to restore the full grid.

    Raises:
        ValueError: if no grid declares the requested kernel. Previously
            any unrecognised name silently selected the rbf grid.
    """
    grids = [self._linear, self._poly, self._rbf]
    if kernel == "any":
        # "any" means no restriction, so search every kernel again.
        self._param_grid = grids
        return
    # Select by the kernel each grid actually declares instead of by the
    # attribute name: the grids stored in _poly and _rbf carry
    # "kernel": ["rbf"] and "kernel": ["poly"] respectively, so choosing
    # by name would search the wrong kernel.
    # NOTE(review): the grid stored in _linear declares no "kernel" key;
    # it is assumed to rely on the estimator's default being linear.
    selected = [
        grid for grid in grids if kernel in grid.get("kernel", ["linear"])
    ]
    if not selected:
        raise ValueError(
            f"unknown kernel {kernel!r}; expected one of "
            "'linear', 'poly', 'rbf', 'any'"
        )
    self._param_grid = selected
|
||||||
|
|
||||||
|
|
||||||
class ModelSVC(ModelBase):
|
class ModelSVC(ModelBase):
|
||||||
def __init__(self, random_state: Optional[int] = None) -> None:
|
def __init__(self, random_state: Optional[int] = None) -> None:
|
||||||
|
Reference in New Issue
Block a user