mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-17 08:26:00 +00:00
Refactor build_predictor
This commit is contained in:
@@ -19,7 +19,7 @@ sys.path.insert(0, os.path.abspath("../../stree/"))
|
||||
# -- Project information -----------------------------------------------------
|
||||
|
||||
project = "STree"
|
||||
copyright = "2021, Ricardo Montañana Gómez"
|
||||
copyright = "2020 - 2021, Ricardo Montañana Gómez"
|
||||
author = "Ricardo Montañana Gómez"
|
||||
|
||||
# The full version, including alpha/beta/rc tags
|
||||
|
@@ -655,7 +655,6 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
self.n_features_in_ = X.shape[1]
|
||||
self.max_features_ = self._initialize_max_features()
|
||||
self.tree_ = self.train(X, y, sample_weight, 1, "root")
|
||||
self._build_predictor()
|
||||
self.X_ = X
|
||||
self.y_ = y
|
||||
return self
|
||||
@@ -703,6 +702,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
if np.unique(y).shape[0] == 1:
|
||||
# only 1 class => pure dataset
|
||||
node.set_title(title + ", <pure>")
|
||||
node.make_predictor()
|
||||
return node
|
||||
# Train the model
|
||||
clf = self._build_clf()
|
||||
@@ -721,6 +721,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
if X_U is None or X_D is None:
|
||||
# didn't part anything
|
||||
node.set_title(title + ", <cgaf>")
|
||||
node.make_predictor()
|
||||
return node
|
||||
node.set_up(
|
||||
self.train(X_U, y_u, sw_u, depth + 1, title + f" - Up({depth+1})")
|
||||
@@ -732,18 +733,6 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
)
|
||||
return node
|
||||
|
||||
def _build_predictor(self):
|
||||
"""Process the leaves to make them predictors"""
|
||||
|
||||
def run_tree(node: Snode):
|
||||
if node.is_leaf():
|
||||
node.make_predictor()
|
||||
return
|
||||
run_tree(node.get_down())
|
||||
run_tree(node.get_up())
|
||||
|
||||
run_tree(self.tree_)
|
||||
|
||||
def _build_clf(self):
|
||||
"""Build the correct classifier for the node"""
|
||||
return (
|
||||
|
Reference in New Issue
Block a user