mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-18 17:06:01 +00:00
#3 Complete multiclass in Stree
Add multiclass dimensions management in distances method Add gamma hyperparameter for non linear kernels
This commit is contained in:
@@ -126,6 +126,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
random_state: int = None,
|
||||
max_depth: int = None,
|
||||
tol: float = 1e-4,
|
||||
gamma="scale",
|
||||
min_samples_split: int = 0,
|
||||
):
|
||||
self.max_iter = max_iter
|
||||
@@ -134,6 +135,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
self.random_state = random_state
|
||||
self.max_depth = max_depth
|
||||
self.tol = tol
|
||||
self.gamma = gamma
|
||||
self.min_samples_split = min_samples_split
|
||||
|
||||
def _more_tags(self) -> dict:
|
||||
@@ -144,21 +146,6 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
"""
|
||||
return {"binary_only": True, "requires_y": True}
|
||||
|
||||
def _linear_function(self, data: np.array, node: Snode) -> np.array:
|
||||
"""Compute the distance of set of samples to a hyperplane, in
|
||||
multiclass classification it should compute the distance to a
|
||||
hyperplane of each class
|
||||
|
||||
:param data: dataset of samples
|
||||
:type data: np.array shape(m, n)
|
||||
:param node: the node that contains the hyperplance coefficients
|
||||
:type node: Snode shape(1, n)
|
||||
:return: array of distances of each sample to the hyperplane
|
||||
:rtype: np.array
|
||||
"""
|
||||
coef = node._clf.coef_[0, :].reshape(-1, data.shape[1])
|
||||
return data.dot(coef.T) + node._clf.intercept_[0]
|
||||
|
||||
def _split_array(self, origin: np.array, down: np.array) -> list:
|
||||
"""Split an array in two based on indices passed as down and its complement
|
||||
|
||||
@@ -170,7 +157,6 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
:rtype: list
|
||||
"""
|
||||
up = ~down
|
||||
print(self.kernel, up.shape, down.shape)
|
||||
return (
|
||||
origin[up[:, 0]] if any(up) else None,
|
||||
origin[down[:, 0]] if any(down) else None,
|
||||
@@ -187,7 +173,12 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
the hyperplane of the node
|
||||
:rtype: np.array
|
||||
"""
|
||||
return np.expand_dims(node._clf.decision_function(data), 1)
|
||||
res = node._clf.decision_function(data)
|
||||
if res.ndim == 1:
|
||||
return np.expand_dims(res, 1)
|
||||
elif res.shape[1] > 1:
|
||||
res = np.delete(res, slice(1, res.shape[1]), axis=1)
|
||||
return res
|
||||
|
||||
def _split_criteria(self, data: np.array) -> np.array:
|
||||
"""Set the criteria to split arrays
|
||||
@@ -256,9 +247,8 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
run_tree(self.tree_)
|
||||
|
||||
def _build_clf(self):
|
||||
""" Select the correct classifier for the node
|
||||
""" Build the correct classifier for the node
|
||||
"""
|
||||
|
||||
return (
|
||||
LinearSVC(
|
||||
max_iter=self.max_iter,
|
||||
@@ -272,6 +262,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
max_iter=self.max_iter,
|
||||
tol=self.tol,
|
||||
C=self.C,
|
||||
gamma=self.gamma,
|
||||
)
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user