Add BaggingStree and BaggingWodt models

This commit is contained in:
2022-01-15 10:41:12 +01:00
parent 2c901d761c
commit d34d71fdef
4 changed files with 20 additions and 13 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -159,7 +159,6 @@ class Experiment:
self.title = title self.title = title
self.stratified = stratified == "1" self.stratified = stratified == "1"
self.stratified_class = StratifiedKFold if self.stratified else KFold self.stratified_class = StratifiedKFold if self.stratified else KFold
self.model = Models.get_model(model_name)
self.datasets = datasets self.datasets = datasets
dictionary = json.loads(hyperparams_dict) dictionary = json.loads(hyperparams_dict)
hyper = BestResults( hyper = BestResults(
@@ -185,8 +184,10 @@ class Experiment:
return self.output_file return self.output_file
def _build_classifier(self, random_state, hyperparameters): def _build_classifier(self, random_state, hyperparameters):
clf = self.model(random_state=random_state) self.model = Models.get_model(self.model_name, random_state)
clf = self.model
clf.set_params(**hyperparameters) clf.set_params(**hyperparameters)
clf.set_params(random_state=random_state)
return clf return clf
def _init_experiment(self): def _init_experiment(self):

View File

@@ -9,23 +9,27 @@ from odte import Odte
class Models: class Models:
@staticmethod @staticmethod
def get_model(name): def get_model(name, random_state=None):
if name == "STree": if name == "STree":
return Stree return Stree()
if name == "Cart": if name == "Cart":
return DecisionTreeClassifier return DecisionTreeClassifier()
if name == "ExtraTree": if name == "ExtraTree":
return ExtraTreeClassifier return ExtraTreeClassifier()
if name == "Wodt": if name == "Wodt":
return Wodt return Wodt()
if name == "SVC": if name == "SVC":
return SVC return SVC()
if name == "ODTE": if name == "ODTE":
return Odte return Odte()
if name == "Bagging": if name == "BaggingStree":
return BaggingClassifier clf = Stree(random_state=random_state)
return BaggingClassifier(base_estimator=clf)
if name == "BaggingWodt":
clf = Wodt(random_state=random_state)
return BaggingClassifier(base_estimator=clf)
if name == "RandomForest": if name == "RandomForest":
return RandomForestClassifier return RandomForestClassifier()
msg = f"No model recognized {name}" msg = f"No model recognized {name}"
if name in ("Stree", "stree"): if name in ("Stree", "stree"):
msg += ", did you mean STree?" msg += ", did you mean STree?"
@@ -43,7 +47,7 @@ class Models:
nodes = 0 nodes = 0
leaves = result.get_n_leaves() leaves = result.get_n_leaves()
depth = 0 depth = 0
elif name == "Bagging": elif name.startswith("Bagging"):
if hasattr(result.base_estimator_, "nodes_leaves"): if hasattr(result.base_estimator_, "nodes_leaves"):
nodes, leaves = list( nodes, leaves = list(
zip(*[x.nodes_leaves() for x in result.estimators_]) zip(*[x.nodes_leaves() for x in result.estimators_])