From df42c0df7471a62eafbc803494fe0d07de9fc032 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Sun, 28 Feb 2021 18:46:58 +0100 Subject: [PATCH] Add models oc1 and cart --- experiment.py | 2 +- experimentation/Models.py | 56 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/experiment.py b/experiment.py index 8ac1401..654f783 100644 --- a/experiment.py +++ b/experiment.py @@ -20,7 +20,7 @@ def parse_arguments() -> Tuple[str, str, str, str, str, bool, bool, dict]: "-m", "--model", type=str, - choices=["stree", "adaBoost", "bagging", "odte"], + choices=["stree", "adaBoost", "bagging", "odte", "oc1", "cart"], required=False, default="stree", ) diff --git a/experimentation/Models.py b/experimentation/Models.py index 274db4e..1956f79 100644 --- a/experimentation/Models.py +++ b/experimentation/Models.py @@ -10,6 +10,7 @@ from sklearn.base import BaseEstimator # type: ignore from sklearn.svm import LinearSVC # type: ignore from sklearn.tree import DecisionTreeClassifier # type: ignore from odte import Odte +from sklearn_oblique_tree.oblique import ObliqueTree class ModelBase(ABC): @@ -39,6 +40,61 @@ class ModelBase(ABC): return result +class ModelCart(ModelBase): + def __init__(self, random_state: Optional[int] = None) -> None: + self._clf = DecisionTreeClassifier() + super().__init__(random_state) + self._model_name = "cart" + self._linear = { + "random_state": [self._random_state], + "criterion": ["gini", "entropy"], + "splitter": ["best", "random"], + "max_features": [None, "sqrt", "auto", "log2"], + } + self._rbf = {} + self._poly = {} + self._param_grid = [ + self._linear, + self._poly, + self._rbf, + ] + + def select_params(self, kernel: str) -> None: + if kernel == "linear": + self._param_grid = [self._linear] + elif kernel == "poly": + self._param_grid = [self._poly] + else: + self._param_grid = [self._rbf] + + +class ModelOc1(ModelBase): + def __init__(self, random_state: Optional[int] = None) -> None: + self._clf = ObliqueTree(splitter="oc1") + super().__init__(random_state) + self._model_name = "oc1" + self._linear = { + "random_state": [self._random_state], + "number_of_restarts": [5, 10, 20, 50], + "max_perturbations": [2, 5, 10], + } + self._rbf = {} + self._poly = {} + self._param_grid = [ + self._linear, + self._poly, + self._rbf, + ] + + def select_params(self, kernel: str) -> None: + if kernel == "linear": + self._param_grid = [self._linear] + elif kernel == "poly": + self._param_grid = [self._poly] + else: + self._param_grid = [self._rbf] + + class ModelStree(ModelBase): def __init__(self, random_state: Optional[int] = None) -> None: self._clf = Stree()