mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-17 16:35:54 +00:00
Add AdaBoostStree model
This commit is contained in:
@@ -1,6 +1,10 @@
|
||||
from statistics import mean
|
||||
from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier
|
||||
from sklearn.ensemble import RandomForestClassifier, BaggingClassifier
|
||||
from sklearn.ensemble import (
|
||||
RandomForestClassifier,
|
||||
BaggingClassifier,
|
||||
AdaBoostClassifier,
|
||||
)
|
||||
from sklearn.svm import SVC
|
||||
from stree import Stree
|
||||
from wodt import Wodt
|
||||
@@ -28,6 +32,9 @@ class Models:
|
||||
if name == "BaggingWodt":
|
||||
clf = Wodt(random_state=random_state)
|
||||
return BaggingClassifier(base_estimator=clf)
|
||||
if name == "AdaBoostStree":
|
||||
clf = Stree(random_state=random_state)
|
||||
return AdaBoostClassifier(base_estimator=clf)
|
||||
if name == "RandomForest":
|
||||
return RandomForestClassifier()
|
||||
msg = f"No model recognized {name}"
|
||||
@@ -47,7 +54,7 @@ class Models:
|
||||
nodes = 0
|
||||
leaves = result.get_n_leaves()
|
||||
depth = 0
|
||||
elif name.startswith("Bagging"):
|
||||
elif name.startswith("Bagging") or name.startswith("AdaBoost"):
|
||||
if hasattr(result.base_estimator_, "nodes_leaves"):
|
||||
nodes, leaves = list(
|
||||
zip(*[x.nodes_leaves() for x in result.estimators_])
|
||||
|
Reference in New Issue
Block a user