Refactor build_predictor

This commit is contained in:
2021-04-19 11:52:00 +02:00
parent 9eb06a9169
commit 2d6921f9a5
2 changed files with 3 additions and 14 deletions

View File

@@ -19,7 +19,7 @@ sys.path.insert(0, os.path.abspath("../../stree/"))
# -- Project information ----------------------------------------------------- # -- Project information -----------------------------------------------------
project = "STree" project = "STree"
copyright = "2021, Ricardo Montañana Gómez" copyright = "2020 - 2021, Ricardo Montañana Gómez"
author = "Ricardo Montañana Gómez" author = "Ricardo Montañana Gómez"
# The full version, including alpha/beta/rc tags # The full version, including alpha/beta/rc tags

View File

@@ -655,7 +655,6 @@ class Stree(BaseEstimator, ClassifierMixin):
self.n_features_in_ = X.shape[1] self.n_features_in_ = X.shape[1]
self.max_features_ = self._initialize_max_features() self.max_features_ = self._initialize_max_features()
self.tree_ = self.train(X, y, sample_weight, 1, "root") self.tree_ = self.train(X, y, sample_weight, 1, "root")
self._build_predictor()
self.X_ = X self.X_ = X
self.y_ = y self.y_ = y
return self return self
@@ -703,6 +702,7 @@ class Stree(BaseEstimator, ClassifierMixin):
if np.unique(y).shape[0] == 1: if np.unique(y).shape[0] == 1:
# only 1 class => pure dataset # only 1 class => pure dataset
node.set_title(title + ", <pure>") node.set_title(title + ", <pure>")
node.make_predictor()
return node return node
# Train the model # Train the model
clf = self._build_clf() clf = self._build_clf()
@@ -721,6 +721,7 @@ class Stree(BaseEstimator, ClassifierMixin):
if X_U is None or X_D is None: if X_U is None or X_D is None:
# didn't part anything # didn't part anything
node.set_title(title + ", <cgaf>") node.set_title(title + ", <cgaf>")
node.make_predictor()
return node return node
node.set_up( node.set_up(
self.train(X_U, y_u, sw_u, depth + 1, title + f" - Up({depth+1})") self.train(X_U, y_u, sw_u, depth + 1, title + f" - Up({depth+1})")
@@ -732,18 +733,6 @@ class Stree(BaseEstimator, ClassifierMixin):
) )
return node return node
def _build_predictor(self):
"""Process the leaves to make them predictors"""
def run_tree(node: Snode):
if node.is_leaf():
node.make_predictor()
return
run_tree(node.get_down())
run_tree(node.get_up())
run_tree(self.tree_)
def _build_clf(self): def _build_clf(self):
"""Build the correct classifier for the node""" """Build the correct classifier for the node"""
return ( return (