mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-17 16:36:01 +00:00
Refactor build_predictor
This commit is contained in:
@@ -19,7 +19,7 @@ sys.path.insert(0, os.path.abspath("../../stree/"))
|
|||||||
# -- Project information -----------------------------------------------------
|
# -- Project information -----------------------------------------------------
|
||||||
|
|
||||||
project = "STree"
|
project = "STree"
|
||||||
copyright = "2021, Ricardo Montañana Gómez"
|
copyright = "2020 - 2021, Ricardo Montañana Gómez"
|
||||||
author = "Ricardo Montañana Gómez"
|
author = "Ricardo Montañana Gómez"
|
||||||
|
|
||||||
# The full version, including alpha/beta/rc tags
|
# The full version, including alpha/beta/rc tags
|
||||||
|
@@ -655,7 +655,6 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
self.n_features_in_ = X.shape[1]
|
self.n_features_in_ = X.shape[1]
|
||||||
self.max_features_ = self._initialize_max_features()
|
self.max_features_ = self._initialize_max_features()
|
||||||
self.tree_ = self.train(X, y, sample_weight, 1, "root")
|
self.tree_ = self.train(X, y, sample_weight, 1, "root")
|
||||||
self._build_predictor()
|
|
||||||
self.X_ = X
|
self.X_ = X
|
||||||
self.y_ = y
|
self.y_ = y
|
||||||
return self
|
return self
|
||||||
@@ -703,6 +702,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
if np.unique(y).shape[0] == 1:
|
if np.unique(y).shape[0] == 1:
|
||||||
# only 1 class => pure dataset
|
# only 1 class => pure dataset
|
||||||
node.set_title(title + ", <pure>")
|
node.set_title(title + ", <pure>")
|
||||||
|
node.make_predictor()
|
||||||
return node
|
return node
|
||||||
# Train the model
|
# Train the model
|
||||||
clf = self._build_clf()
|
clf = self._build_clf()
|
||||||
@@ -721,6 +721,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
if X_U is None or X_D is None:
|
if X_U is None or X_D is None:
|
||||||
# didn't part anything
|
# didn't part anything
|
||||||
node.set_title(title + ", <cgaf>")
|
node.set_title(title + ", <cgaf>")
|
||||||
|
node.make_predictor()
|
||||||
return node
|
return node
|
||||||
node.set_up(
|
node.set_up(
|
||||||
self.train(X_U, y_u, sw_u, depth + 1, title + f" - Up({depth+1})")
|
self.train(X_U, y_u, sw_u, depth + 1, title + f" - Up({depth+1})")
|
||||||
@@ -732,18 +733,6 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
)
|
)
|
||||||
return node
|
return node
|
||||||
|
|
||||||
def _build_predictor(self):
|
|
||||||
"""Process the leaves to make them predictors"""
|
|
||||||
|
|
||||||
def run_tree(node: Snode):
|
|
||||||
if node.is_leaf():
|
|
||||||
node.make_predictor()
|
|
||||||
return
|
|
||||||
run_tree(node.get_down())
|
|
||||||
run_tree(node.get_up())
|
|
||||||
|
|
||||||
run_tree(self.tree_)
|
|
||||||
|
|
||||||
def _build_clf(self):
|
def _build_clf(self):
|
||||||
"""Build the correct classifier for the node"""
|
"""Build the correct classifier for the node"""
|
||||||
return (
|
return (
|
||||||
|
Reference in New Issue
Block a user