Add BaggingStree and BaggingWodt models

This commit is contained in:
2022-01-15 10:41:12 +01:00
parent 2c901d761c
commit d34d71fdef
4 changed files with 20 additions and 13 deletions

View File

@@ -9,23 +9,27 @@ from odte import Odte
class Models:
@staticmethod
def get_model(name):
def get_model(name, random_state=None):
if name == "STree":
return Stree
return Stree()
if name == "Cart":
return DecisionTreeClassifier
return DecisionTreeClassifier()
if name == "ExtraTree":
return ExtraTreeClassifier
return ExtraTreeClassifier()
if name == "Wodt":
return Wodt
return Wodt()
if name == "SVC":
return SVC
return SVC()
if name == "ODTE":
return Odte
if name == "Bagging":
return BaggingClassifier
return Odte()
if name == "BaggingStree":
clf = Stree(random_state=random_state)
return BaggingClassifier(base_estimator=clf)
if name == "BaggingWodt":
clf = Wodt(random_state=random_state)
return BaggingClassifier(base_estimator=clf)
if name == "RandomForest":
return RandomForestClassifier
return RandomForestClassifier()
msg = f"No model recognized {name}"
if name in ("Stree", "stree"):
msg += ", did you mean STree?"
@@ -43,7 +47,7 @@ class Models:
nodes = 0
leaves = result.get_n_leaves()
depth = 0
elif name == "Bagging":
elif name.startswith("Bagging"):
if hasattr(result.base_estimator_, "nodes_leaves"):
nodes, leaves = list(
zip(*[x.nodes_leaves() for x in result.estimators_])