Add Random Forest

This commit is contained in:
2022-01-14 14:07:58 +01:00
parent bae3b676ec
commit f43622504c
9 changed files with 31 additions and 325 deletions

View File

@@ -1,7 +1,9 @@
from statistics import mean
from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier
from sklearn.ensemble import RandomForestClassifier, BaggingClassifier
from sklearn.svm import SVC
from stree import Stree
from wodt import TreeClassifier
from wodt import Wodt
from odte import Odte
@@ -15,11 +17,15 @@ class Models:
if name == "ExtraTree":
return ExtraTreeClassifier
if name == "Wodt":
return TreeClassifier
return Wodt
if name == "SVC":
return SVC
if name == "ODTE":
return Odte
if name == "Bagging":
return BaggingClassifier
if name == "RandomForest":
return RandomForestClassifier
msg = f"No model recognized {name}"
if name in ("Stree", "stree"):
msg += ", did you mean STree?"
@@ -37,6 +43,21 @@ class Models:
nodes = 0
leaves = result.get_n_leaves()
depth = 0
elif name=="Bagging":
if hasattr(result.base_estimator_, "nodes_leaves"):
nodes, leaves = list(zip(*[x.nodes_leaves() for x in result.estimators_]))
nodes, leaves = mean(nodes), mean(leaves)
depth = mean([x.depth_ for x in result.estimators_])
elif hasattr(result.base_estimator_, "tree_"):
nodes = mean([x.tree_.node_count for x in result.estimators_])
leaves = mean([x.get_n_leaves() for x in result.estimators_])
depth = mean([x.get_depth() for x in result.estimators_])
else:
nodes = leaves=depth=0
elif name == "RandomForest":
leaves = mean([x.get_n_leaves() for x in result.estimators_])
depth = mean([x.get_depth() for x in result.estimators_])
nodes = mean([x.tree_.node_count for x in result.estimators_])
elif name == "SVC":
nodes = leaves = depth = 0
else: