mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-16 16:06:01 +00:00
Refactor train method
This commit is contained in:
@@ -79,6 +79,30 @@ class Snode:
|
|||||||
def set_down(self, son):
|
def set_down(self, son):
|
||||||
self._down = son
|
self._down = son
|
||||||
|
|
||||||
|
def set_title(self, title):
|
||||||
|
self._title = title
|
||||||
|
|
||||||
|
def set_classifier(self, clf):
|
||||||
|
self._clf = clf
|
||||||
|
|
||||||
|
def set_features(self, features):
|
||||||
|
self._features = features
|
||||||
|
|
||||||
|
def set_impurity(self, impurity):
|
||||||
|
self._impurity = impurity
|
||||||
|
|
||||||
|
def get_title(self) -> str:
|
||||||
|
return self._title
|
||||||
|
|
||||||
|
def get_classifier(self) -> SVC:
|
||||||
|
return self._clf
|
||||||
|
|
||||||
|
def get_impurity(self) -> float:
|
||||||
|
return self._impurity
|
||||||
|
|
||||||
|
def get_features(self) -> np.array:
|
||||||
|
return self._features
|
||||||
|
|
||||||
def set_up(self, son):
|
def set_up(self, son):
|
||||||
self._up = son
|
self._up = son
|
||||||
|
|
||||||
@@ -636,38 +660,26 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
y = y[~indices_zero]
|
y = y[~indices_zero]
|
||||||
sample_weight = sample_weight[~indices_zero]
|
sample_weight = sample_weight[~indices_zero]
|
||||||
self.depth_ = max(depth, self.depth_)
|
self.depth_ = max(depth, self.depth_)
|
||||||
|
node = Snode(None, X, y, X.shape[1], 0.0, title, sample_weight)
|
||||||
if np.unique(y).shape[0] == 1:
|
if np.unique(y).shape[0] == 1:
|
||||||
# only 1 class => pure dataset
|
# only 1 class => pure dataset
|
||||||
return Snode(
|
node.set_title(title + ", <pure>")
|
||||||
clf=None,
|
return node
|
||||||
X=X,
|
|
||||||
y=y,
|
|
||||||
features=X.shape[1],
|
|
||||||
impurity=0.0,
|
|
||||||
title=title + ", <pure>",
|
|
||||||
weight=sample_weight,
|
|
||||||
)
|
|
||||||
# Train the model
|
# Train the model
|
||||||
clf = self._build_clf()
|
clf = self._build_clf()
|
||||||
Xs, features = self.splitter_.get_subspace(X, y, self.max_features_)
|
Xs, features = self.splitter_.get_subspace(X, y, self.max_features_)
|
||||||
clf.fit(Xs, y, sample_weight=sample_weight)
|
clf.fit(Xs, y, sample_weight=sample_weight)
|
||||||
impurity = self.splitter_.partition_impurity(y)
|
node.set_impurity(self.splitter_.partition_impurity(y))
|
||||||
node = Snode(clf, X, y, features, impurity, title, sample_weight)
|
node.set_classifier(clf)
|
||||||
|
node.set_features(features)
|
||||||
self.splitter_.partition(X, node, True)
|
self.splitter_.partition(X, node, True)
|
||||||
X_U, X_D = self.splitter_.part(X)
|
X_U, X_D = self.splitter_.part(X)
|
||||||
y_u, y_d = self.splitter_.part(y)
|
y_u, y_d = self.splitter_.part(y)
|
||||||
sw_u, sw_d = self.splitter_.part(sample_weight)
|
sw_u, sw_d = self.splitter_.part(sample_weight)
|
||||||
if X_U is None or X_D is None:
|
if X_U is None or X_D is None:
|
||||||
# didn't part anything
|
# didn't part anything
|
||||||
return Snode(
|
node.set_title(title + ", <cgaf>")
|
||||||
clf,
|
return node
|
||||||
X,
|
|
||||||
y,
|
|
||||||
features=X.shape[1],
|
|
||||||
impurity=impurity,
|
|
||||||
title=title + ", <cgaf>",
|
|
||||||
weight=sample_weight,
|
|
||||||
)
|
|
||||||
node.set_up(
|
node.set_up(
|
||||||
self.train(X_U, y_u, sw_u, depth + 1, title + f" - Up({depth+1})")
|
self.train(X_U, y_u, sw_u, depth + 1, title + f" - Up({depth+1})")
|
||||||
)
|
)
|
||||||
|
@@ -69,6 +69,31 @@ class Snode_test(unittest.TestCase):
|
|||||||
self.assertEqual(0.75, test._belief)
|
self.assertEqual(0.75, test._belief)
|
||||||
self.assertEqual(-1, test._partition_column)
|
self.assertEqual(-1, test._partition_column)
|
||||||
|
|
||||||
|
def test_set_title(self):
|
||||||
|
test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test")
|
||||||
|
self.assertEqual("test", test.get_title())
|
||||||
|
test.set_title("another")
|
||||||
|
self.assertEqual("another", test.get_title())
|
||||||
|
|
||||||
|
def test_set_classifier(self):
|
||||||
|
test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test")
|
||||||
|
clf = Stree()
|
||||||
|
self.assertIsNone(test.get_classifier())
|
||||||
|
test.set_classifier(clf)
|
||||||
|
self.assertEqual(clf, test.get_classifier())
|
||||||
|
|
||||||
|
def test_set_impurity(self):
|
||||||
|
test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test")
|
||||||
|
self.assertEqual(0.0, test.get_impurity())
|
||||||
|
test.set_impurity(54.7)
|
||||||
|
self.assertEqual(54.7, test.get_impurity())
|
||||||
|
|
||||||
|
def test_set_features(self):
|
||||||
|
test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [0, 1], 0.0, "test")
|
||||||
|
self.assertListEqual([0, 1], test.get_features())
|
||||||
|
test.set_features([1, 2])
|
||||||
|
self.assertListEqual([1, 2], test.get_features())
|
||||||
|
|
||||||
def test_make_predictor_on_not_leaf(self):
|
def test_make_predictor_on_not_leaf(self):
|
||||||
test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test")
|
test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test")
|
||||||
test.set_up(Snode(None, [1], [1], [], 0.0, "another_test"))
|
test.set_up(Snode(None, [1], [1], [], 0.0, "another_test"))
|
||||||
|
Reference in New Issue
Block a user