Add BaggingStree and BaggingWodt models

This commit is contained in:
2022-01-15 10:41:12 +01:00
parent 2c901d761c
commit d34d71fdef
4 changed files with 20 additions and 13 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -159,7 +159,6 @@ class Experiment:
self.title = title
self.stratified = stratified == "1"
self.stratified_class = StratifiedKFold if self.stratified else KFold
self.model = Models.get_model(model_name)
self.datasets = datasets
dictionary = json.loads(hyperparams_dict)
hyper = BestResults(
@@ -185,8 +184,10 @@ class Experiment:
return self.output_file
def _build_classifier(self, random_state, hyperparameters):
clf = self.model(random_state=random_state)
self.model = Models.get_model(self.model_name, random_state)
clf = self.model
clf.set_params(**hyperparameters)
clf.set_params(random_state=random_state)
return clf
def _init_experiment(self):

View File

@@ -9,23 +9,27 @@ from odte import Odte
class Models:
@staticmethod
def get_model(name):
def get_model(name, random_state=None):
if name == "STree":
return Stree
return Stree()
if name == "Cart":
return DecisionTreeClassifier
return DecisionTreeClassifier()
if name == "ExtraTree":
return ExtraTreeClassifier
return ExtraTreeClassifier()
if name == "Wodt":
return Wodt
return Wodt()
if name == "SVC":
return SVC
return SVC()
if name == "ODTE":
return Odte
if name == "Bagging":
return BaggingClassifier
return Odte()
if name == "BaggingStree":
clf = Stree(random_state=random_state)
return BaggingClassifier(base_estimator=clf)
if name == "BaggingWodt":
clf = Wodt(random_state=random_state)
return BaggingClassifier(base_estimator=clf)
if name == "RandomForest":
return RandomForestClassifier
return RandomForestClassifier()
msg = f"No model recognized {name}"
if name in ("Stree", "stree"):
msg += ", did you mean STree?"
@@ -43,7 +47,7 @@ class Models:
nodes = 0
leaves = result.get_n_leaves()
depth = 0
elif name == "Bagging":
elif name.startswith("Bagging"):
if hasattr(result.base_estimator_, "nodes_leaves"):
nodes, leaves = list(
zip(*[x.nodes_leaves() for x in result.estimators_])