Add wodt clf

Add execution results of RaF, RoF and RRoF
Fix fit time in database records
This commit is contained in:
2021-03-10 01:37:00 +01:00
parent f52565b2a5
commit d4cfe77b18
14 changed files with 782 additions and 9 deletions

View File

@@ -345,8 +345,8 @@ class Outcomes(BD):
float(results["test_score"].std()),
],
[
float(results["score_time"].mean()),
float(results["score_time"].std()),
float(results["fit_time"].mean()),
float(results["fit_time"].std()),
],
parameters,
)
@@ -441,8 +441,8 @@ class Hyperparameters(BD):
float(outcomes["test_score_std"]),
]
time_spent = [
float(outcomes["score_time"]),
float(outcomes["score_time_std"]),
float(outcomes["fit_time"]),
float(outcomes["fit_time_std"]),
]
self.mirror(
grid_type,

View File

@@ -11,6 +11,7 @@ from sklearn.svm import LinearSVC # type: ignore
from sklearn.tree import DecisionTreeClassifier # type: ignore
from odte import Odte
from sklearn_oblique_tree.oblique import ObliqueTree
from wodt import TreeClassifier
class ModelBase(ABC):
@@ -95,6 +96,31 @@ class ModelOc1(ModelBase):
self._param_grid = [self._rbf]
class ModelWodt(ModelBase):
def __init__(self, random_state: Optional[int] = None) -> None:
self._clf = TreeClassifier()
super().__init__(random_state)
self._model_name = "wodt"
self._linear = {
"random_state": [self._random_state],
}
self._rbf = {}
self._poly = {}
self._param_grid = [
self._linear,
self._poly,
self._rbf,
]
def select_params(self, kernel: str) -> None:
if kernel == "linear":
self._param_grid = [self._linear]
elif kernel == "poly":
self._param_grid = [self._poly]
else:
self._param_grid = [self._rbf]
class ModelStree(ModelBase):
def __init__(self, random_state: Optional[int] = None) -> None:
self._clf = Stree()