#2 Cosmetic and style updates

This commit is contained in:
2020-06-15 11:09:11 +02:00
parent 736ab7ef20
commit 9334951d1b
3 changed files with 10 additions and 8 deletions

View File

@@ -268,7 +268,8 @@ class Splitter:
data = self.decision_criteria(data, node._y) data = self.decision_criteria(data, node._y)
self._down = data > 0 self._down = data > 0
def _distances(self, node: Snode, data: np.ndarray) -> np.array: @staticmethod
def _distances(node: Snode, data: np.ndarray) -> np.array:
"""Compute distances of the samples to the hyperplane of the node """Compute distances of the samples to the hyperplane of the node
:param node: node containing the svm classifier :param node: node containing the svm classifier
@@ -498,7 +499,8 @@ class Stree(BaseEstimator, ClassifierMixin):
) )
) )
def _reorder_results(self, y: np.array, indices: np.array) -> np.array: @staticmethod
def _reorder_results(y: np.array, indices: np.array) -> np.array:
"""Reorder an array based on the array of indices passed """Reorder an array based on the array of indices passed
:param y: data untidy :param y: data untidy
@@ -579,7 +581,7 @@ class Stree(BaseEstimator, ClassifierMixin):
X, y = check_X_y(X, y) X, y = check_X_y(X, y)
y_pred = self.predict(X).reshape(y.shape) y_pred = self.predict(X).reshape(y.shape)
# Compute accuracy for each possible representation # Compute accuracy for each possible representation
y_type, y_true, y_pred = _check_targets(y, y_pred) _, y_true, y_pred = _check_targets(y, y_pred)
check_consistent_length(y_true, y_pred, sample_weight) check_consistent_length(y_true, y_pred, sample_weight)
score = y_true == y_pred score = y_true == y_pred
return _weighted_sum(score, sample_weight, normalize=True) return _weighted_sum(score, sample_weight, normalize=True)

View File

@@ -13,8 +13,8 @@ class Splitter_test(unittest.TestCase):
self._random_state = 1 self._random_state = 1
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@staticmethod
def build( def build(
self,
clf=LinearSVC(), clf=LinearSVC(),
min_samples_split=0, min_samples_split=0,
splitter_type="random", splitter_type="random",

View File

@@ -67,9 +67,8 @@ class Stree_test(unittest.TestCase):
clf.fit(*load_dataset(self._random_state)) clf.fit(*load_dataset(self._random_state))
self._check_tree(clf.tree_) self._check_tree(clf.tree_)
def _find_out( @staticmethod
self, px: np.array, x_original: np.array, y_original def _find_out(px: np.array, x_original: np.array, y_original) -> list:
) -> list:
"""Find the original values of y for a given array of samples """Find the original values of y for a given array of samples
Arguments: Arguments:
@@ -163,7 +162,8 @@ class Stree_test(unittest.TestCase):
self.assertListEqual(expected, computed) self.assertListEqual(expected, computed)
self.assertEqual(expected_string, str(clf)) self.assertEqual(expected_string, str(clf))
def test_is_a_sklearn_classifier(self): @staticmethod
def test_is_a_sklearn_classifier():
import warnings import warnings
from sklearn.exceptions import ConvergenceWarning from sklearn.exceptions import ConvergenceWarning