mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-18 08:55:53 +00:00
Add BaggingStree and BaggingWodt models
This commit is contained in:
@@ -9,23 +9,27 @@ from odte import Odte
|
||||
|
||||
class Models:
|
||||
@staticmethod
|
||||
def get_model(name):
|
||||
def get_model(name, random_state=None):
|
||||
if name == "STree":
|
||||
return Stree
|
||||
return Stree()
|
||||
if name == "Cart":
|
||||
return DecisionTreeClassifier
|
||||
return DecisionTreeClassifier()
|
||||
if name == "ExtraTree":
|
||||
return ExtraTreeClassifier
|
||||
return ExtraTreeClassifier()
|
||||
if name == "Wodt":
|
||||
return Wodt
|
||||
return Wodt()
|
||||
if name == "SVC":
|
||||
return SVC
|
||||
return SVC()
|
||||
if name == "ODTE":
|
||||
return Odte
|
||||
if name == "Bagging":
|
||||
return BaggingClassifier
|
||||
return Odte()
|
||||
if name == "BaggingStree":
|
||||
clf = Stree(random_state=random_state)
|
||||
return BaggingClassifier(base_estimator=clf)
|
||||
if name == "BaggingWodt":
|
||||
clf = Wodt(random_state=random_state)
|
||||
return BaggingClassifier(base_estimator=clf)
|
||||
if name == "RandomForest":
|
||||
return RandomForestClassifier
|
||||
return RandomForestClassifier()
|
||||
msg = f"No model recognized {name}"
|
||||
if name in ("Stree", "stree"):
|
||||
msg += ", did you mean STree?"
|
||||
@@ -43,7 +47,7 @@ class Models:
|
||||
nodes = 0
|
||||
leaves = result.get_n_leaves()
|
||||
depth = 0
|
||||
elif name == "Bagging":
|
||||
elif name.startswith("Bagging"):
|
||||
if hasattr(result.base_estimator_, "nodes_leaves"):
|
||||
nodes, leaves = list(
|
||||
zip(*[x.nodes_leaves() for x in result.estimators_])
|
||||
|
Reference in New Issue
Block a user