mirror of
https://github.com/Doctorado-ML/Stree_datasets.git
synced 2025-08-15 23:46:03 +00:00
Add wodt clf
Add execution results of RaF, RoF and RRoF Fix fit time in database records
This commit is contained in:
@@ -11,6 +11,7 @@ from sklearn.svm import LinearSVC # type: ignore
|
||||
from sklearn.tree import DecisionTreeClassifier # type: ignore
|
||||
from odte import Odte
|
||||
from sklearn_oblique_tree.oblique import ObliqueTree
|
||||
from wodt import TreeClassifier
|
||||
|
||||
|
||||
class ModelBase(ABC):
|
||||
@@ -95,6 +96,31 @@ class ModelOc1(ModelBase):
|
||||
self._param_grid = [self._rbf]
|
||||
|
||||
|
||||
class ModelWodt(ModelBase):
|
||||
def __init__(self, random_state: Optional[int] = None) -> None:
|
||||
self._clf = TreeClassifier()
|
||||
super().__init__(random_state)
|
||||
self._model_name = "wodt"
|
||||
self._linear = {
|
||||
"random_state": [self._random_state],
|
||||
}
|
||||
self._rbf = {}
|
||||
self._poly = {}
|
||||
self._param_grid = [
|
||||
self._linear,
|
||||
self._poly,
|
||||
self._rbf,
|
||||
]
|
||||
|
||||
def select_params(self, kernel: str) -> None:
|
||||
if kernel == "linear":
|
||||
self._param_grid = [self._linear]
|
||||
elif kernel == "poly":
|
||||
self._param_grid = [self._poly]
|
||||
else:
|
||||
self._param_grid = [self._rbf]
|
||||
|
||||
|
||||
class ModelStree(ModelBase):
|
||||
def __init__(self, random_state: Optional[int] = None) -> None:
|
||||
self._clf = Stree()
|
||||
|
Reference in New Issue
Block a user