mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-16 16:06:01 +00:00
#2 Cosmetic and style updates
This commit is contained in:
@@ -268,7 +268,8 @@ class Splitter:
|
|||||||
data = self.decision_criteria(data, node._y)
|
data = self.decision_criteria(data, node._y)
|
||||||
self._down = data > 0
|
self._down = data > 0
|
||||||
|
|
||||||
def _distances(self, node: Snode, data: np.ndarray) -> np.array:
|
@staticmethod
|
||||||
|
def _distances(node: Snode, data: np.ndarray) -> np.array:
|
||||||
"""Compute distances of the samples to the hyperplane of the node
|
"""Compute distances of the samples to the hyperplane of the node
|
||||||
|
|
||||||
:param node: node containing the svm classifier
|
:param node: node containing the svm classifier
|
||||||
@@ -498,7 +499,8 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _reorder_results(self, y: np.array, indices: np.array) -> np.array:
|
@staticmethod
|
||||||
|
def _reorder_results(y: np.array, indices: np.array) -> np.array:
|
||||||
"""Reorder an array based on the array of indices passed
|
"""Reorder an array based on the array of indices passed
|
||||||
|
|
||||||
:param y: data untidy
|
:param y: data untidy
|
||||||
@@ -579,7 +581,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
X, y = check_X_y(X, y)
|
X, y = check_X_y(X, y)
|
||||||
y_pred = self.predict(X).reshape(y.shape)
|
y_pred = self.predict(X).reshape(y.shape)
|
||||||
# Compute accuracy for each possible representation
|
# Compute accuracy for each possible representation
|
||||||
y_type, y_true, y_pred = _check_targets(y, y_pred)
|
_, y_true, y_pred = _check_targets(y, y_pred)
|
||||||
check_consistent_length(y_true, y_pred, sample_weight)
|
check_consistent_length(y_true, y_pred, sample_weight)
|
||||||
score = y_true == y_pred
|
score = y_true == y_pred
|
||||||
return _weighted_sum(score, sample_weight, normalize=True)
|
return _weighted_sum(score, sample_weight, normalize=True)
|
||||||
|
@@ -13,8 +13,8 @@ class Splitter_test(unittest.TestCase):
|
|||||||
self._random_state = 1
|
self._random_state = 1
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def build(
|
def build(
|
||||||
self,
|
|
||||||
clf=LinearSVC(),
|
clf=LinearSVC(),
|
||||||
min_samples_split=0,
|
min_samples_split=0,
|
||||||
splitter_type="random",
|
splitter_type="random",
|
||||||
|
@@ -67,9 +67,8 @@ class Stree_test(unittest.TestCase):
|
|||||||
clf.fit(*load_dataset(self._random_state))
|
clf.fit(*load_dataset(self._random_state))
|
||||||
self._check_tree(clf.tree_)
|
self._check_tree(clf.tree_)
|
||||||
|
|
||||||
def _find_out(
|
@staticmethod
|
||||||
self, px: np.array, x_original: np.array, y_original
|
def _find_out(px: np.array, x_original: np.array, y_original) -> list:
|
||||||
) -> list:
|
|
||||||
"""Find the original values of y for a given array of samples
|
"""Find the original values of y for a given array of samples
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
@@ -163,7 +162,8 @@ class Stree_test(unittest.TestCase):
|
|||||||
self.assertListEqual(expected, computed)
|
self.assertListEqual(expected, computed)
|
||||||
self.assertEqual(expected_string, str(clf))
|
self.assertEqual(expected_string, str(clf))
|
||||||
|
|
||||||
def test_is_a_sklearn_classifier(self):
|
@staticmethod
|
||||||
|
def test_is_a_sklearn_classifier():
|
||||||
import warnings
|
import warnings
|
||||||
from sklearn.exceptions import ConvergenceWarning
|
from sklearn.exceptions import ConvergenceWarning
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user