mirror of
https://github.com/Doctorado-ML/Stree_datasets.git
synced 2025-08-15 15:36:01 +00:00
Add models oc1 and cart
This commit is contained in:
@@ -20,7 +20,7 @@ def parse_arguments() -> Tuple[str, str, str, str, str, bool, bool, dict]:
|
|||||||
"-m",
|
"-m",
|
||||||
"--model",
|
"--model",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["stree", "adaBoost", "bagging", "odte"],
|
choices=["stree", "adaBoost", "bagging", "odte", "oc1", "cart"],
|
||||||
required=False,
|
required=False,
|
||||||
default="stree",
|
default="stree",
|
||||||
)
|
)
|
||||||
|
@@ -10,6 +10,7 @@ from sklearn.base import BaseEstimator # type: ignore
|
|||||||
from sklearn.svm import LinearSVC # type: ignore
|
from sklearn.svm import LinearSVC # type: ignore
|
||||||
from sklearn.tree import DecisionTreeClassifier # type: ignore
|
from sklearn.tree import DecisionTreeClassifier # type: ignore
|
||||||
from odte import Odte
|
from odte import Odte
|
||||||
|
from sklearn_oblique_tree.oblique import ObliqueTree
|
||||||
|
|
||||||
|
|
||||||
class ModelBase(ABC):
|
class ModelBase(ABC):
|
||||||
@@ -39,6 +40,61 @@ class ModelBase(ABC):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class ModelCart(ModelBase):
|
||||||
|
def __init__(self, random_state: Optional[int] = None) -> None:
|
||||||
|
self._clf = DecisionTreeClassifier()
|
||||||
|
super().__init__(random_state)
|
||||||
|
self._model_name = "cart"
|
||||||
|
self._linear = {
|
||||||
|
"random_state": [self._random_state],
|
||||||
|
"criterion": ["gini", "entropy"],
|
||||||
|
"splitter": ["best", "random"],
|
||||||
|
"max_features": [None, "sqrt", "auto", "log2"],
|
||||||
|
}
|
||||||
|
self._rbf = {}
|
||||||
|
self._poly = {}
|
||||||
|
self._param_grid = [
|
||||||
|
self._linear,
|
||||||
|
self._poly,
|
||||||
|
self._rbf,
|
||||||
|
]
|
||||||
|
|
||||||
|
def select_params(self, kernel: str) -> None:
|
||||||
|
if kernel == "linear":
|
||||||
|
self._param_grid = [self._linear]
|
||||||
|
elif kernel == "poly":
|
||||||
|
self._param_grid = [self._poly]
|
||||||
|
else:
|
||||||
|
self._param_grid = [self._rbf]
|
||||||
|
|
||||||
|
|
||||||
|
class ModelOc1(ModelBase):
|
||||||
|
def __init__(self, random_state: Optional[int] = None) -> None:
|
||||||
|
self._clf = ObliqueTree(splitter="oc1")
|
||||||
|
super().__init__(random_state)
|
||||||
|
self._model_name = "oc1"
|
||||||
|
self._linear = {
|
||||||
|
"random_state": [self._random_state],
|
||||||
|
"number_of_restarts": [5, 10, 20, 50],
|
||||||
|
"max_perturbations": [2, 5, 10],
|
||||||
|
}
|
||||||
|
self._rbf = {}
|
||||||
|
self._poly = {}
|
||||||
|
self._param_grid = [
|
||||||
|
self._linear,
|
||||||
|
self._poly,
|
||||||
|
self._rbf,
|
||||||
|
]
|
||||||
|
|
||||||
|
def select_params(self, kernel: str) -> None:
|
||||||
|
if kernel == "linear":
|
||||||
|
self._param_grid = [self._linear]
|
||||||
|
elif kernel == "poly":
|
||||||
|
self._param_grid = [self._poly]
|
||||||
|
else:
|
||||||
|
self._param_grid = [self._rbf]
|
||||||
|
|
||||||
|
|
||||||
class ModelStree(ModelBase):
|
class ModelStree(ModelBase):
|
||||||
def __init__(self, random_state: Optional[int] = None) -> None:
|
def __init__(self, random_state: Optional[int] = None) -> None:
|
||||||
self._clf = Stree()
|
self._clf = Stree()
|
||||||
|
Reference in New Issue
Block a user