Add wodt clf

Add execution results of RaF, RoF and RRoF
Fix fit time in database records
This commit is contained in:
2021-03-10 01:37:00 +01:00
parent f52565b2a5
commit d4cfe77b18
14 changed files with 782 additions and 9 deletions

View File

@@ -11,6 +11,7 @@ from sklearn.svm import LinearSVC # type: ignore
from sklearn.tree import DecisionTreeClassifier # type: ignore
from odte import Odte
from sklearn_oblique_tree.oblique import ObliqueTree
from wodt import TreeClassifier
class ModelBase(ABC):
@@ -95,6 +96,31 @@ class ModelOc1(ModelBase):
self._param_grid = [self._rbf]
class ModelWodt(ModelBase):
def __init__(self, random_state: Optional[int] = None) -> None:
self._clf = TreeClassifier()
super().__init__(random_state)
self._model_name = "wodt"
self._linear = {
"random_state": [self._random_state],
}
self._rbf = {}
self._poly = {}
self._param_grid = [
self._linear,
self._poly,
self._rbf,
]
def select_params(self, kernel: str) -> None:
if kernel == "linear":
self._param_grid = [self._linear]
elif kernel == "poly":
self._param_grid = [self._poly]
else:
self._param_grid = [self._rbf]
class ModelStree(ModelBase):
def __init__(self, random_state: Optional[int] = None) -> None:
self._clf = Stree()