mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-18 17:06:01 +00:00
Implement ovo strategy (#37)
* Implement ovo strategy * Set ovo strategy as default * Add kernel liblinear with LinearSVC classifier * Fix weak test
This commit is contained in:
committed by
GitHub
parent
5cef0f4875
commit
4f04e72670
@@ -155,6 +155,10 @@ class Siterator:
|
||||
self._stack = []
|
||||
self._push(tree)
|
||||
|
||||
def __iter__(self):
|
||||
# To complete the iterator interface
|
||||
return self
|
||||
|
||||
def _push(self, node: Snode):
|
||||
if node is not None:
|
||||
self._stack.append(node)
|
||||
@@ -373,16 +377,17 @@ class Splitter:
|
||||
tuple
|
||||
indices of the features selected
|
||||
"""
|
||||
# No feature reduction
|
||||
if dataset.shape[1] == max_features:
|
||||
# No feature reduction applies
|
||||
return tuple(range(dataset.shape[1]))
|
||||
# Random feature reduction
|
||||
if self._feature_select == "random":
|
||||
features_sets = self._generate_spaces(
|
||||
dataset.shape[1], max_features
|
||||
)
|
||||
return self._select_best_set(dataset, labels, features_sets)
|
||||
# return the KBest features
|
||||
if self._feature_select == "best":
|
||||
# Take KBest features
|
||||
return (
|
||||
SelectKBest(k=max_features)
|
||||
.fit(dataset, labels)
|
||||
@@ -569,6 +574,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
min_samples_split: int = 0,
|
||||
max_features=None,
|
||||
splitter: str = "random",
|
||||
multiclass_strategy: str = "ovo",
|
||||
normalize: bool = False,
|
||||
):
|
||||
self.max_iter = max_iter
|
||||
@@ -585,6 +591,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
self.criterion = criterion
|
||||
self.splitter = splitter
|
||||
self.normalize = normalize
|
||||
self.multiclass_strategy = multiclass_strategy
|
||||
|
||||
def _more_tags(self) -> dict:
|
||||
"""Required by sklearn to supply features of the classifier
|
||||
@@ -629,7 +636,23 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
f"Maximum depth has to be greater than 1... got (max_depth=\
|
||||
{self.max_depth})"
|
||||
)
|
||||
kernels = ["linear", "rbf", "poly", "sigmoid"]
|
||||
if self.multiclass_strategy not in ["ovr", "ovo"]:
|
||||
raise ValueError(
|
||||
"mutliclass_strategy has to be either ovr or ovo"
|
||||
f" but got {self.multiclass_strategy}"
|
||||
)
|
||||
if self.multiclass_strategy == "ovo":
|
||||
if self.kernel == "liblinear":
|
||||
raise ValueError(
|
||||
"The kernel liblinear is incompatible with ovo "
|
||||
"multiclass_strategy"
|
||||
)
|
||||
if self.split_criteria == "max_samples":
|
||||
raise ValueError(
|
||||
"The multiclass_strategy 'ovo' is incompatible with "
|
||||
"split_criteria 'max_samples'"
|
||||
)
|
||||
kernels = ["liblinear", "linear", "rbf", "poly", "sigmoid"]
|
||||
if self.kernel not in kernels:
|
||||
raise ValueError(f"Kernel {self.kernel} not in {kernels}")
|
||||
check_classification_targets(y)
|
||||
@@ -749,7 +772,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
C=self.C,
|
||||
tol=self.tol,
|
||||
)
|
||||
if self.kernel == "linear"
|
||||
if self.kernel == "liblinear"
|
||||
else SVC(
|
||||
kernel=self.kernel,
|
||||
max_iter=self.max_iter,
|
||||
@@ -758,6 +781,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
gamma=self.gamma,
|
||||
degree=self.degree,
|
||||
random_state=self.random_state,
|
||||
decision_function_shape=self.multiclass_strategy,
|
||||
)
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user