diff --git a/stree/Strees.py b/stree/Strees.py index c67308e..5464768 100644 --- a/stree/Strees.py +++ b/stree/Strees.py @@ -7,9 +7,9 @@ Build an oblique tree classifier based on SVM Trees """ import os -import numbers import random import warnings +from typing import Optional, List, Union, Tuple from math import log from itertools import combinations import numpy as np @@ -78,10 +78,10 @@ class Snode: def is_leaf(self) -> bool: return self._up is None and self._down is None - def get_down(self) -> "Snode": + def get_down(self) -> Optional["Snode"]: return self._down - def get_up(self) -> "Snode": + def get_up(self) -> Optional["Snode"]: return self._up def make_predictor(self): @@ -123,11 +123,11 @@ class Siterator: """Stree preorder iterator """ - def __init__(self, tree: Snode): - self._stack = [] + def __init__(self, tree: Optional[Snode]): + self._stack: List[Snode] = [] self._push(tree) - def _push(self, node: Snode): + def _push(self, node: Optional[Snode]) -> None: if node is not None: self._stack.append(node) @@ -150,7 +150,7 @@ class Splitter: min_samples_split: int = None, random_state=None, ): - self._clf = clf + self._clf: Union[SVC, LinearSVC] = clf self._random_state = random_state if random_state is not None: random.seed(random_state) @@ -230,8 +230,8 @@ class Splitter: def _select_best_set( self, dataset: np.array, labels: np.array, features_sets: list ) -> list: - max_gain = 0 - selected = None + max_gain: float = 0.0 + selected: Union[List[int], None] = None warnings.filterwarnings("ignore", category=ConvergenceWarning) for feature_set in features_sets: self._clf.fit(dataset[:, feature_set], labels) @@ -265,7 +265,7 @@ class Splitter: def get_subspace( self, dataset: np.array, labels: np.array, max_features: int - ) -> list: + ) -> Tuple[np.array, np.array]: """Return the best subspace to make a split """ indices = self._get_subspaces_set(dataset, labels, max_features) @@ -478,7 +478,7 @@ class Stree(BaseEstimator, ClassifierMixin): sample_weight: np.ndarray, depth: int, title: str, - ) -> Snode: + ) -> Optional[Snode]: """Recursive function to split the original dataset into predictor nodes (leaves) @@ -543,11 +543,13 @@ class Stree(BaseEstimator, ClassifierMixin): node.set_down(self.train(X_D, y_d, sw_d, depth + 1, title + " - Down")) return node - def _build_predictor(self): + def _build_predictor(self) -> None: """Process the leaves to make them predictors """ - def run_tree(node: Snode): + def run_tree(node: Optional[Snode]) -> None: + if node is None: + raise ValueError("Can't build predictors on None") if node.is_leaf(): node.make_predictor() return @@ -556,7 +558,7 @@ class Stree(BaseEstimator, ClassifierMixin): run_tree(self.tree_) - def _build_clf(self): + def _build_clf(self) -> Union[LinearSVC, SVC]: """ Build the correct classifier for the node """ return ( @@ -605,7 +607,7 @@ class Stree(BaseEstimator, ClassifierMixin): """ def predict_class( - xp: np.array, indices: np.array, node: Snode + xp: np.array, indices: np.array, node: Optional[Snode] ) -> np.array: if xp is None: return [], [] @@ -704,7 +706,7 @@ class Stree(BaseEstimator, ClassifierMixin): ) elif self.max_features is None: max_features = self.n_features_ - elif isinstance(self.max_features, numbers.Integral): + elif isinstance(self.max_features, int): max_features = self.max_features else: # float if self.max_features > 0.0: diff --git a/stree/tests/Stree_test.py b/stree/tests/Stree_test.py index e16a69f..e4715a6 100644 --- a/stree/tests/Stree_test.py +++ b/stree/tests/Stree_test.py @@ -414,3 +414,24 @@ class Stree_test(unittest.TestCase): # zero weights are ok when they don't erase a class _ = clf.train(X, y, weights_no_zero, 1, "test") self.assertListEqual(weights_no_zero.tolist(), original.tolist()) + + def test_build_predictor(self): + X, y = load_dataset(self._random_state) + clf = Stree(random_state=self._random_state) + with self.assertRaises(ValueError): + clf.tree_ = None + clf._build_predictor() + clf.fit(X, y) + node = clf.tree_.get_down().get_down() + expected_impurity = 0.04686951386893923 + expected_class = 1 + expected_belief = 0.9759887005649718 + self.assertAlmostEqual(expected_impurity, node._impurity) + self.assertAlmostEqual(expected_belief, node._belief) + self.assertEqual(expected_class, node._class) + node._belief = 0.0 + node._class = None + clf._build_predictor() + node = clf.tree_.get_down().get_down() + self.assertAlmostEqual(expected_belief, node._belief) + self.assertEqual(expected_class, node._class)