# Source listing: benchmark/src/Models.py
# Last modified 2021-11-10 11:12:10 +01:00 - 47 lines, 1.4 KiB, Python

from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier
from sklearn.svm import SVC
from stree import Stree
from wodt import TreeClassifier
from odte import Odte
class Models:
    """Name-based lookup of the benchmark classifiers and their complexity.

    Both helpers are static; the class only serves as a namespace.
    """

    # Canonical model name -> zero-argument resolver returning the class.
    # A class is looked up only when its resolver is called, so rejecting an
    # unrecognized name never touches any of the classifier classes.
    _RESOLVERS = {
        "STree": lambda: Stree,
        "Cart": lambda: DecisionTreeClassifier,
        "ExtraTree": lambda: ExtraTreeClassifier,
        "Wodt": lambda: TreeClassifier,
        "SVC": lambda: SVC,
        "ODTE": lambda: Odte,
    }

    @staticmethod
    def get_model(name):
        """Return the classifier class registered under ``name``.

        The class itself is returned, not an instance.

        Args:
            name: One of "STree", "Cart", "ExtraTree", "Wodt", "SVC" or
                "ODTE". Matching is case-sensitive.

        Returns:
            The class bound to ``name``.

        Raises:
            ValueError: If ``name`` is not a registered model. When it
                differs from a registered name only by letter case, the
                message also suggests the canonical spelling.
        """
        try:
            resolver = Models._RESOLVERS[name]
        except (KeyError, TypeError):
            # TypeError: an unhashable name cannot be a key, hence unknown.
            resolver = None
        if resolver is not None:
            return resolver()

        # Raised outside the ``except`` block so the ValueError carries no
        # chained KeyError context.
        msg = f"No model recognized {name}"
        if isinstance(name, str):
            lowered = name.lower()
            for known in Models._RESOLVERS:
                if known.lower() == lowered:
                    msg += f", did you mean {known}?"
                    break
        raise ValueError(msg)

    @staticmethod
    def get_complexity(name, result):
        """Return ``(nodes, leaves, depth)`` for the model object ``result``.

        Args:
            name: Model name; selects how the figures are read.
            result: Model object to inspect. Which attributes it needs
                depends on ``name``:
                "Cart" reads ``tree_.node_count``, ``tree_.max_depth`` and
                calls ``get_n_leaves()``;
                "ExtraTree" only calls ``get_n_leaves()``;
                "SVC" does not use ``result`` at all;
                any other name calls ``nodes_leaves()`` and reads the
                optional ``depth_`` attribute.

        Returns:
            A 3-tuple ``(nodes, leaves, depth)``. Figures that are not read
            for a given model are reported as 0.
        """
        if name == "Cart":
            tree = result.tree_
            nodes = tree.node_count
            depth = tree.max_depth
            return nodes, result.get_n_leaves(), depth
        if name == "ExtraTree":
            # NOTE(review): scikit-learn's ExtraTreeClassifier appears to
            # expose tree_.node_count / tree_.max_depth as well; the zeros
            # are kept so reported results do not change - confirm whether
            # they are intentional.
            return 0, result.get_n_leaves(), 0
        if name == "SVC":
            return 0, 0, 0
        nodes, leaves = result.nodes_leaves()
        # depth_ is optional on these models; report 0 when it is missing.
        return nodes, leaves, getattr(result, "depth_", 0)