Add models oc1 and cart

This commit is contained in:
2021-02-28 18:46:58 +01:00
parent 091edbe3dc
commit df42c0df74
2 changed files with 57 additions and 1 deletions

View File

@@ -20,7 +20,7 @@ def parse_arguments() -> Tuple[str, str, str, str, str, bool, bool, dict]:
"-m", "-m",
"--model", "--model",
type=str, type=str,
choices=["stree", "adaBoost", "bagging", "odte"], choices=["stree", "adaBoost", "bagging", "odte", "oc1", "cart"],
required=False, required=False,
default="stree", default="stree",
) )

View File

@@ -10,6 +10,7 @@ from sklearn.base import BaseEstimator # type: ignore
from sklearn.svm import LinearSVC # type: ignore from sklearn.svm import LinearSVC # type: ignore
from sklearn.tree import DecisionTreeClassifier # type: ignore from sklearn.tree import DecisionTreeClassifier # type: ignore
from odte import Odte from odte import Odte
from sklearn_oblique_tree.oblique import ObliqueTree
class ModelBase(ABC): class ModelBase(ABC):
@@ -39,6 +40,61 @@ class ModelBase(ABC):
return result return result
class ModelCart(ModelBase):
def __init__(self, random_state: Optional[int] = None) -> None:
self._clf = DecisionTreeClassifier()
super().__init__(random_state)
self._model_name = "cart"
self._linear = {
"random_state": [self._random_state],
"criterion": ["gini", "entropy"],
"splitter": ["best", "random"],
"max_features": [None, "sqrt", "auto", "log2"],
}
self._rbf = {}
self._poly = {}
self._param_grid = [
self._linear,
self._poly,
self._rbf,
]
def select_params(self, kernel: str) -> None:
if kernel == "linear":
self._param_grid = [self._linear]
elif kernel == "poly":
self._param_grid = [self._poly]
else:
self._param_grid = [self._rbf]
class ModelOc1(ModelBase):
def __init__(self, random_state: Optional[int] = None) -> None:
self._clf = ObliqueTree(splitter="oc1")
super().__init__(random_state)
self._model_name = "oc1"
self._linear = {
"random_state": [self._random_state],
"number_of_restarts": [5, 10, 20, 50],
"max_perturbations": [2, 5, 10],
}
self._rbf = {}
self._poly = {}
self._param_grid = [
self._linear,
self._poly,
self._rbf,
]
def select_params(self, kernel: str) -> None:
if kernel == "linear":
self._param_grid = [self._linear]
elif kernel == "poly":
self._param_grid = [self._poly]
else:
self._param_grid = [self._rbf]
class ModelStree(ModelBase): class ModelStree(ModelBase):
def __init__(self, random_state: Optional[int] = None) -> None: def __init__(self, random_state: Optional[int] = None) -> None:
self._clf = Stree() self._clf = Stree()