diff --git a/setup.py b/setup.py index e0bbd04..b31585f 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ import setuptools -__version__ = "0.9rc2" +__version__ = "0.9rc3" __author__ = "Ricardo Montañana Gómez" def readme(): diff --git a/stree/Strees.py b/stree/Strees.py index 99fbba3..b938ad5 100644 --- a/stree/Strees.py +++ b/stree/Strees.py @@ -15,6 +15,7 @@ from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.svm import LinearSVC from sklearn.utils.validation import check_X_y, check_array, check_is_fitted + class Snode: def __init__(self, clf: LinearSVC, X: np.ndarray, y: np.ndarray, title: str): self._clf = clf @@ -22,7 +23,7 @@ class Snode: self._interceptor = 0. if clf is None else clf.intercept_ self._title = title self._belief = 0. # belief of the prediction in a leaf node based on samples - # Only store dataset in Testing + # Only store dataset in Testing self._X = X if os.environ.get('TESTING', 'NS') != 'NS' else None self._y = y self._down = None @@ -97,24 +98,27 @@ class Siterator: self._push(node.get_down()) return node + class Stree(BaseEstimator, ClassifierMixin): """ """ + __folder = 'data/' def __init__(self, C: float = 1.0, max_iter: int = 1000, random_state: int = 0, use_predictions: bool = False): - self._max_iter = max_iter - self._C = C - self._random_state = random_state - self._tree = None - self.__folder = 'data/' - self.__use_predictions = use_predictions - self.__trained = False - self.__proba = False + self.max_iter = max_iter + self.C = C + self.random_state = random_state + self.use_predictions = use_predictions def get_params(self, deep=True): """Get dict with hyperparameters and its values to accomplish sklearn rules """ - return {"C": self._C, "random_state": self._random_state, 'max_iter': self._max_iter} + return { + 'C': self.C, + 'random_state': self.random_state, + 'max_iter': self.max_iter, + 'use_predictions': self.use_predictions + } def set_params(self, **parameters): """Set hyperparmeters as specified by sklearn, needed in Gridsearchs @@ -123,12 +127,16 @@ class Stree(BaseEstimator, ClassifierMixin): setattr(self, parameter, value) return self + # Added binary_only tag as required by sklearn check_estimator + def _more_tags(self): + return {'binary_only': True} + def _linear_function(self, data: np.array, node: Snode) -> np.array: coef = node._vector[0, :].reshape(-1, data.shape[1]) return data.dot(coef.T) + node._interceptor[0] def _split_data(self, node: Snode, data: np.ndarray, indices: np.ndarray) -> list: - if self.__use_predictions: + if self.use_predictions: yp = node._clf.predict(data) down = (yp == 1).reshape(-1, 1) res = np.expand_dims(node._clf.decision_function(data), 1) @@ -147,11 +155,16 @@ class Stree(BaseEstimator, ClassifierMixin): return [data_up, indices_up, data_down, indices_down, res_up, res_down] def fit(self, X: np.ndarray, y: np.ndarray, title: str = 'root') -> 'Stree': - X, y = check_X_y(X, y.ravel()) + from sklearn.utils.multiclass import check_classification_targets + if type(y).__name__ == 'np.ndarray': + y = y.ravel() + X, y = check_X_y(X, y) + self.classes_ = np.unique(y) + self.n_iter_ = self.max_iter + check_classification_targets(y) self.n_features_in_ = X.shape[1] - self._tree = self.train(X, y.ravel(), title) + self.tree_ = self.train(X, y.ravel(), title) self._build_predictor() - self.__trained = True return self def _build_predictor(self): @@ -165,15 +178,15 @@ class Stree(BaseEstimator, ClassifierMixin): run_tree(node.get_down()) run_tree(node.get_up()) - run_tree(self._tree) + run_tree(self.tree_) def train(self, X: np.ndarray, y: np.ndarray, title: str = 'root') -> Snode: if np.unique(y).shape[0] == 1: # only 1 class => pure dataset return Snode(None, X, y, title + ', ') # Train the model - clf = LinearSVC(max_iter=self._max_iter, C=self._C, - random_state=self._random_state) + clf = LinearSVC(max_iter=self.max_iter, C=self.C, + random_state=self.random_state) clf.fit(X, y) tree = Snode(clf, X, y, title) X_U, y_u, X_D, y_d, _, _ = self._split_data(tree, X, y) @@ -184,8 +197,13 @@ class Stree(BaseEstimator, ClassifierMixin): tree.set_down(self.train(X_D, y_d, title + ' - Down')) return tree - def _reorder_results(self, y: np.array, indices: np.array) -> np.array: - y_ordered = np.zeros(y.shape, dtype=int if y.ndim == 1 else float) + def _reorder_results(self, y: np.array, indices: np.array, proba=False) -> np.array: + if proba: + # if predict_proba return np.array of floats + y_ordered = np.zeros(y.shape, dtype=float) + else: + # return array of same type given in y + y_ordered = y.copy() indices = indices.astype(int) for i, index in enumerate(indices): y_ordered[index] = y[i] @@ -205,17 +223,15 @@ class Stree(BaseEstimator, ClassifierMixin): return np.append(k, m), np.append(l, n) # sklearn check - check_is_fitted(self) + check_is_fitted(self, ['tree_']) # Input validation X = check_array(X) # setup prediction & make it happen indices = np.arange(X.shape[0]) - return self._reorder_results(*predict_class(X, indices, self._tree)) + return self._reorder_results(*predict_class(X, indices, self.tree_)).ravel() def predict_proba(self, X: np.array) -> np.array: - """Computes an approximation of the probability of samples belonging to class 1 - (nothing more, nothing less) - + """Computes an approximation of the probability of samples belonging to class 0 and 1 :param X: dataset :type X: np.array """ @@ -247,29 +263,31 @@ class Stree(BaseEstimator, ClassifierMixin): return np.append(k, m), np.append(l, n) # sklearn check - check_is_fitted(self) + check_is_fitted(self, ['tree_']) # Input validation X = check_array(X) # setup prediction & make it happen indices = np.arange(X.shape[0]) - result, indices = predict_class(X, indices, [], self._tree) + empty_dist = np.empty((X.shape[0], 1), dtype=float) + result, indices = predict_class(X, indices, empty_dist, self.tree_) result = result.reshape(X.shape[0], 2) # Turn distances to hyperplane into probabilities based on fitting distances # of samples to its hyperplane that classified them, to the sigmoid function - result[:, 1] = 1 / (1 + np.exp(-result[:, 1])) - return self._reorder_results(result, indices) + result[:, 1] = 1 / (1 + np.exp(-result[:, 1])) # Probability of being 1 + result[:, 0] = 1 - result[:, 1] # Probability of being 0 + return self._reorder_results(result, indices, proba=True) def score(self, X: np.array, y: np.array) -> float: """Return accuracy """ - if not self.__trained: - self.fit(X, y) + # sklearn check + check_is_fitted(self) yp = self.predict(X).reshape(y.shape) right = (yp == y).astype(int) return np.sum(right) / len(y) def __iter__(self): - return Siterator(self._tree) + return Siterator(self.tree_) def __str__(self) -> str: output = '' @@ -305,7 +323,5 @@ class Stree(BaseEstimator, ClassifierMixin): if not os.path.isdir(self.__folder): os.mkdir(self.__folder) with open(self.get_catalog_name(), 'w', encoding='utf-8') as catalog: - self._save_datasets(self._tree, catalog, 1) - - + self._save_datasets(self.tree_, catalog, 1) diff --git a/stree/Strees_grapher.py b/stree/Strees_grapher.py index 19d2516..6d0e46a 100644 --- a/stree/Strees_grapher.py +++ b/stree/Strees_grapher.py @@ -143,7 +143,7 @@ class Stree_grapher(Stree): self._pca = PCA(n_components=3) X = self._pca.fit_transform(X) res = super().fit(X, y) - self._tree_gr = self._copy_tree(self._tree) + self._tree_gr = self._copy_tree(self.tree_) self._fitted = True return res diff --git a/stree/tests/Strees_test.py b/stree/tests/Strees_test.py index 5d52fe9..ef9053d 100644 --- a/stree/tests/Strees_test.py +++ b/stree/tests/Strees_test.py @@ -71,7 +71,7 @@ class Stree_test(unittest.TestCase): def test_build_tree(self): """Check if the tree is built the same way as predictions of models """ - self._check_tree(self._clf._tree) + self._check_tree(self._clf.tree_) def _get_file_data(self, file_name: str) -> tuple: """Return X, y from data, y is the last column in array @@ -145,12 +145,14 @@ class Stree_test(unittest.TestCase): """ # Element 28 has a different prediction than the truth decimals = 5 + prob = 0.29026400766 X, y = self._get_Xy() yp = self._clf.predict_proba(X[28, :].reshape(-1, X.shape[1])) - self.assertEqual(0, yp[0:, 0]) + self.assertEqual(np.round(1 - prob, decimals), np.round(yp[0:, 0], decimals)) self.assertEqual(1, y[28]) + self.assertAlmostEqual( - round(0.29026400766, decimals), + round(prob, decimals), round(yp[0, 1], decimals), decimals ) @@ -161,7 +163,7 @@ class Stree_test(unittest.TestCase): decimals = 5 X, y = self._get_Xy() yp = self._clf.predict_proba(X[:num, :]) - self.assertListEqual(y[:num].tolist(), yp[:, 0].tolist()) + self.assertListEqual(y[:num].tolist(), np.argmax(yp[:num], axis=1).tolist()) expected_proba = [0.88395641, 0.36746962, 0.84158767, 0.34106833, 0.14269291, 0.85193236, 0.29876058, 0.7282164, 0.85958616, 0.89517877, 0.99745224, 0.18860349, 0.30756427, 0.8318412, 0.18981198, 0.15564624, 0.25740655, 0.22923355, @@ -243,6 +245,14 @@ class Stree_test(unittest.TestCase): computed.append(str(node)) self.assertListEqual(expected, computed) + def test_is_a_sklearn_classifier(self): + import warnings + from sklearn.exceptions import ConvergenceWarning + warnings.filterwarnings('ignore', category=ConvergenceWarning) + warnings.filterwarnings('ignore', category=RuntimeWarning) + from sklearn.utils.estimator_checks import check_estimator + check_estimator(Stree()) + class Snode_test(unittest.TestCase): def __init__(self, *args, **kwargs): @@ -291,7 +301,7 @@ class Snode_test(unittest.TestCase): class_computed = classes[card == max_card] self.assertEqual(class_computed, node._class) - check_leave(self._clf._tree) + check_leave(self._clf.tree_) def test_nodes_coefs(self): """Check if the nodes of the tree have the right attributes filled @@ -309,5 +319,4 @@ class Snode_test(unittest.TestCase): run_tree(node.get_down()) run_tree(node.get_up()) - run_tree(self._clf._tree) - + run_tree(self._clf.tree_) diff --git a/test2.ipynb b/test2.ipynb index 9e088c5..20316bc 100644 --- a/test2.ipynb +++ b/test2.ipynb @@ -48,7 +48,7 @@ { "output_type": "stream", "name": "stdout", - "text": "Fraud: 0.173% 492\nValid: 99.827% 284315\nX.shape (1492, 28) y.shape (1492,)\nFraud: 33.110% 494\nValid: 66.890% 998\n" + "text": "Fraud: 0.173% 492\nValid: 99.827% 284315\nX.shape (1492, 28) y.shape (1492,)\nFraud: 32.976% 492\nValid: 67.024% 1000\n" } ], "source": [ @@ -94,12 +94,16 @@ { "cell_type": "code", "execution_count": 5, - "metadata": {}, + "metadata": { + "tags": [ + "outputPrepend" + ] + }, "outputs": [ { "output_type": "stream", "name": "stdout", - "text": "************** C=0.001 ****************************\nClassifier's accuracy (train): 0.9521\nClassifier's accuracy (test) : 0.9598\nroot\nroot - Down, - Leaf class=1 belief=0.980519 counts=(array([0, 1]), array([ 6, 302]))\nroot - Up, - Leaf class=0 belief=0.940217 counts=(array([0, 1]), array([692, 44]))\n\n**************************************************\n************** C=0.01 ****************************\nClassifier's accuracy (train): 0.9521\nClassifier's accuracy (test) : 0.9643\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=0.986842 counts=(array([0, 1]), array([ 4, 300]))\nroot - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up, - Leaf class=0 belief=0.937754 counts=(array([0, 1]), array([693, 46]))\n\n**************************************************\n************** C=1 ****************************\nClassifier's accuracy (train): 0.9636\nClassifier's accuracy (test) : 0.9688\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([308]))\nroot - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([8]))\nroot - Up, - Leaf class=0 belief=0.947802 counts=(array([0, 1]), array([690, 38]))\n\n**************************************************\n************** C=5 ****************************\nClassifier's accuracy (train): 0.9665\nClassifier's accuracy (test) : 0.9621\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([308]))\nroot - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([11]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up\nroot - Up - Up - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Up, - Leaf class=0 belief=0.951456 counts=(array([0, 1]), array([686, 35]))\n\n**************************************************\n************** C=17 ****************************\nClassifier's accuracy (train): 0.9741\nClassifier's accuracy (test) : 0.9576\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([306]))\nroot - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([10]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nroot - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([3]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down\nroot - Up - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([7]))\nroot - Up - Up - Up - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([4]))\nroot - Up - Up - Up - Up - Up - Up, - Leaf class=0 belief=0.961538 counts=(array([0, 1]), array([675, 27]))\n\n**************************************************\n0.7816 secs\n" + "text": "************** C=0.001 ****************************\nClassifier's accuracy (train): 0.9579\nClassifier's accuracy (test) : 0.9509\nroot\nroot - Down, - Leaf class=1 belief=0.987013 counts=(array([0, 1]), array([ 4, 304]))\nroot - Up, - Leaf class=0 belief=0.945652 counts=(array([0, 1]), array([696, 40]))\n\n**************************************************\n************** C=0.01 ****************************\nClassifier's accuracy (train): 0.9579\nClassifier's accuracy (test) : 0.9509\nroot\nroot - Down, - Leaf class=1 belief=0.990196 counts=(array([0, 1]), array([ 3, 303]))\nroot - Up, - Leaf class=0 belief=0.944444 counts=(array([0, 1]), array([697, 41]))\n\n**************************************************\n************** C=1 ****************************\nClassifier's accuracy (train): 0.9693\nClassifier's accuracy (test) : 0.9576\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([311]))\nroot - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([6]))\nroot - Up\nroot - Up - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up, - Leaf class=0 belief=0.955923 counts=(array([0, 1]), array([694, 32]))\n\n**************************************************\n************** C=5 ****************************\nClassifier's accuracy (train): 0.9713\nClassifier's accuracy (test) : 0.9576\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([314]))\nroot - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([6]))\nroot - Up, - Leaf class=0 belief=0.958564 counts=(array([0, 1]), array([694, 30]))\n\n**************************************************\n************** C=17 ****************************\nClassifier's accuracy (train): 0.9780\nClassifier's accuracy (test) : 0.9420\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([301]))\nroot - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([13]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([17]))\nroot - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([3]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down\nroot - Up - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([2]))\nroot - Up - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up - Up - Up, - Leaf class=0 belief=0.967376 counts=(array([0, 1]), array([682, 23]))\n\n**************************************************\n0.4537 secs\n" } ], "source": [ @@ -140,7 +144,7 @@ { "output_type": "stream", "name": "stdout", - "text": "root\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([306]))\nroot - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([10]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nroot - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([3]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down\nroot - Up - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([7]))\nroot - Up - Up - Up - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([4]))\nroot - Up - Up - Up - Up - Up - Up, - Leaf class=0 belief=0.961538 counts=(array([0, 1]), array([675, 27]))\n" + "text": "root\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([301]))\nroot - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([13]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([17]))\nroot - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([3]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down\nroot - Up - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([2]))\nroot - Up - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up - Up - Up, - Leaf class=0 belief=0.967376 counts=(array([0, 1]), array([682, 23]))\n" } ], "source": [ @@ -157,7 +161,7 @@ { "output_type": "stream", "name": "stdout", - "text": "root\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([306]))\nroot - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([10]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nroot - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([3]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down\nroot - Up - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([7]))\nroot - Up - Up - Up - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([4]))\nroot - Up - Up - Up - Up - Up - Up, - Leaf class=0 belief=0.961538 counts=(array([0, 1]), array([675, 27]))\n" + "text": "root\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([301]))\nroot - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([13]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([17]))\nroot - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([3]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down\nroot - Up - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([2]))\nroot - Up - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up - Up - Up, - Leaf class=0 belief=0.967376 counts=(array([0, 1]), array([682, 23]))\n" } ], "source": [ @@ -165,6 +169,38 @@ "for i in clf:\n", " print(i)" ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "# Check if the classifier is a sklearn estimator\n", + "from sklearn.utils.estimator_checks import check_estimator\n", + "check_estimator(Stree())" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": "1 functools.partial(, 'Stree')\n2 functools.partial(, 'Stree')\n3 functools.partial(, 'Stree')\n4 functools.partial(, 'Stree')\n5 functools.partial(, 'Stree')\n6 functools.partial(, 'Stree')\n7 functools.partial(, 'Stree')\n8 functools.partial(, 'Stree')\n9 functools.partial(, 'Stree', readonly_memmap=True)\n10 functools.partial(, 'Stree')\n11 functools.partial(, 'Stree')\n12 functools.partial(, 'Stree')\n13 functools.partial(, 'Stree')\n14 functools.partial(, 'Stree')\n15 functools.partial(, 'Stree')\n16 functools.partial(, 'Stree')\n17 functools.partial(, 'Stree')\n18 functools.partial(, 'Stree')\n19 functools.partial(, 'Stree')\n20 functools.partial(, 'Stree')\n21 functools.partial(, 'Stree')\n22 functools.partial(, 'Stree')\n23 functools.partial(, 'Stree', readonly_memmap=True)\n24 functools.partial(, 'Stree')\n25 functools.partial(, 'Stree')\n26 functools.partial(, 'Stree')\n27 functools.partial(, 'Stree')\n28 functools.partial(, 'Stree')\n29 functools.partial(, 'Stree')\n30 functools.partial(, 'Stree')\n31 functools.partial(, 'Stree')\n32 functools.partial(, 'Stree')\n33 functools.partial(, 'Stree')\n34 functools.partial(, 'Stree')\n35 functools.partial(, 'Stree')\n36 functools.partial(, 'Stree')\n37 functools.partial(, 'Stree')\n38 functools.partial(, 'Stree')\n39 functools.partial(, 'Stree')\n" + } + ], + "source": [ + "# Make checks one by one\n", + "c = 0\n", + "checks = check_estimator(Stree(), generate_only=True)\n", + "for check in checks:\n", + " c += 1\n", + " print(c, check[1])\n", + " check[1](check[0])" + ] } ], "metadata": {