Add models oc1 and cart

This commit is contained in:
2021-02-28 18:46:58 +01:00
parent 091edbe3dc
commit df42c0df74
2 changed files with 57 additions and 1 deletions

View File

@@ -10,6 +10,7 @@ from sklearn.base import BaseEstimator # type: ignore
from sklearn.svm import LinearSVC # type: ignore
from sklearn.tree import DecisionTreeClassifier # type: ignore
from odte import Odte
from sklearn_oblique_tree.oblique import ObliqueTree
class ModelBase(ABC):
@@ -39,6 +40,61 @@ class ModelBase(ABC):
return result
class ModelCart(ModelBase):
def __init__(self, random_state: Optional[int] = None) -> None:
self._clf = DecisionTreeClassifier()
super().__init__(random_state)
self._model_name = "cart"
self._linear = {
"random_state": [self._random_state],
"criterion": ["gini", "entropy"],
"splitter": ["best", "random"],
"max_features": [None, "sqrt", "auto", "log2"],
}
self._rbf = {}
self._poly = {}
self._param_grid = [
self._linear,
self._poly,
self._rbf,
]
def select_params(self, kernel: str) -> None:
if kernel == "linear":
self._param_grid = [self._linear]
elif kernel == "poly":
self._param_grid = [self._poly]
else:
self._param_grid = [self._rbf]
class ModelOc1(ModelBase):
def __init__(self, random_state: Optional[int] = None) -> None:
self._clf = ObliqueTree(splitter="oc1")
super().__init__(random_state)
self._model_name = "oc1"
self._linear = {
"random_state": [self._random_state],
"number_of_restarts": [5, 10, 20, 50],
"max_perturbations": [2, 5, 10],
}
self._rbf = {}
self._poly = {}
self._param_grid = [
self._linear,
self._poly,
self._rbf,
]
def select_params(self, kernel: str) -> None:
if kernel == "linear":
self._param_grid = [self._linear]
elif kernel == "poly":
self._param_grid = [self._poly]
else:
self._param_grid = [self._rbf]
class ModelStree(ModelBase):
def __init__(self, random_state: Optional[int] = None) -> None:
self._clf = Stree()