From 2d6921f9a51e5f52688fd74b12cb19ea32e3fa61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Mon, 19 Apr 2021 11:52:00 +0200 Subject: [PATCH] Refactor build_predictor --- docs/source/conf.py | 2 +- stree/Strees.py | 15 ++------------- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 639a13d..443436b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -19,7 +19,7 @@ sys.path.insert(0, os.path.abspath("../../stree/")) # -- Project information ----------------------------------------------------- project = "STree" -copyright = "2021, Ricardo Montañana Gómez" +copyright = "2020 - 2021, Ricardo Montañana Gómez" author = "Ricardo Montañana Gómez" # The full version, including alpha/beta/rc tags diff --git a/stree/Strees.py b/stree/Strees.py index 000b066..84a39e9 100644 --- a/stree/Strees.py +++ b/stree/Strees.py @@ -655,7 +655,6 @@ class Stree(BaseEstimator, ClassifierMixin): self.n_features_in_ = X.shape[1] self.max_features_ = self._initialize_max_features() self.tree_ = self.train(X, y, sample_weight, 1, "root") - self._build_predictor() self.X_ = X self.y_ = y return self @@ -703,6 +702,7 @@ class Stree(BaseEstimator, ClassifierMixin): if np.unique(y).shape[0] == 1: # only 1 class => pure dataset node.set_title(title + ", ") + node.make_predictor() return node # Train the model clf = self._build_clf() @@ -721,6 +721,7 @@ class Stree(BaseEstimator, ClassifierMixin): if X_U is None or X_D is None: # didn't part anything node.set_title(title + ", ") + node.make_predictor() return node node.set_up( self.train(X_U, y_u, sw_u, depth + 1, title + f" - Up({depth+1})") @@ -732,18 +733,6 @@ class Stree(BaseEstimator, ClassifierMixin): ) return node - def _build_predictor(self): - """Process the leaves to make them predictors""" - - def run_tree(node: Snode): - if node.is_leaf(): - node.make_predictor() - return - run_tree(node.get_down()) - run_tree(node.get_up()) - - run_tree(self.tree_) - def _build_clf(self): """Build the correct classifier for the node""" return (