mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-18 08:55:53 +00:00
Add Random Forest
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
from statistics import mean
|
||||
from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier
|
||||
from sklearn.ensemble import RandomForestClassifier, BaggingClassifier
|
||||
from sklearn.svm import SVC
|
||||
from stree import Stree
|
||||
from wodt import TreeClassifier
|
||||
from wodt import Wodt
|
||||
from odte import Odte
|
||||
|
||||
|
||||
@@ -15,11 +17,15 @@ class Models:
|
||||
if name == "ExtraTree":
|
||||
return ExtraTreeClassifier
|
||||
if name == "Wodt":
|
||||
return TreeClassifier
|
||||
return Wodt
|
||||
if name == "SVC":
|
||||
return SVC
|
||||
if name == "ODTE":
|
||||
return Odte
|
||||
if name == "Bagging":
|
||||
return BaggingClassifier
|
||||
if name == "RandomForest":
|
||||
return RandomForestClassifier
|
||||
msg = f"No model recognized {name}"
|
||||
if name in ("Stree", "stree"):
|
||||
msg += ", did you mean STree?"
|
||||
@@ -37,6 +43,21 @@ class Models:
|
||||
nodes = 0
|
||||
leaves = result.get_n_leaves()
|
||||
depth = 0
|
||||
elif name=="Bagging":
|
||||
if hasattr(result.base_estimator_, "nodes_leaves"):
|
||||
nodes, leaves = list(zip(*[x.nodes_leaves() for x in result.estimators_]))
|
||||
nodes, leaves = mean(nodes), mean(leaves)
|
||||
depth = mean([x.depth_ for x in result.estimators_])
|
||||
elif hasattr(result.base_estimator_, "tree_"):
|
||||
nodes = mean([x.tree_.node_count for x in result.estimators_])
|
||||
leaves = mean([x.get_n_leaves() for x in result.estimators_])
|
||||
depth = mean([x.get_depth() for x in result.estimators_])
|
||||
else:
|
||||
nodes = leaves=depth=0
|
||||
elif name == "RandomForest":
|
||||
leaves = mean([x.get_n_leaves() for x in result.estimators_])
|
||||
depth = mean([x.get_depth() for x in result.estimators_])
|
||||
nodes = mean([x.tree_.node_count for x in result.estimators_])
|
||||
elif name == "SVC":
|
||||
nodes = leaves = depth = 0
|
||||
else:
|
||||
|
Reference in New Issue
Block a user