mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-15 15:36:00 +00:00
Add a test to the tests set Add depth to node description Fix iterator and str test due to this addon
This commit is contained in:
committed by
GitHub
parent
f438124057
commit
460c63a6d0
@@ -635,6 +635,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
X = X[~indices_zero, :]
|
||||
y = y[~indices_zero]
|
||||
sample_weight = sample_weight[~indices_zero]
|
||||
self.depth_ = max(depth, self.depth_)
|
||||
if np.unique(y).shape[0] == 1:
|
||||
# only 1 class => pure dataset
|
||||
return Snode(
|
||||
@@ -652,7 +653,6 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
clf.fit(Xs, y, sample_weight=sample_weight)
|
||||
impurity = self.splitter_.partition_impurity(y)
|
||||
node = Snode(clf, X, y, features, impurity, title, sample_weight)
|
||||
self.depth_ = max(depth, self.depth_)
|
||||
self.splitter_.partition(X, node, True)
|
||||
X_U, X_D = self.splitter_.part(X)
|
||||
y_u, y_d = self.splitter_.part(y)
|
||||
@@ -668,8 +668,14 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
title=title + ", <cgaf>",
|
||||
weight=sample_weight,
|
||||
)
|
||||
node.set_up(self.train(X_U, y_u, sw_u, depth + 1, title + " - Up"))
|
||||
node.set_down(self.train(X_D, y_d, sw_d, depth + 1, title + " - Down"))
|
||||
node.set_up(
|
||||
self.train(X_U, y_u, sw_u, depth + 1, title + f" - Up({depth+1})")
|
||||
)
|
||||
node.set_down(
|
||||
self.train(
|
||||
X_D, y_d, sw_d, depth + 1, title + f" - Down({depth+1})"
|
||||
)
|
||||
)
|
||||
return node
|
||||
|
||||
def _build_predictor(self):
|
||||
|
@@ -103,20 +103,20 @@ class Stree_test(unittest.TestCase):
|
||||
def test_iterator_and_str(self):
|
||||
"""Check preorder iterator"""
|
||||
expected = [
|
||||
"root feaures=(0, 1, 2) impurity=1.0000 counts=(array([0, 1]), arr"
|
||||
"ay([750, 750]))",
|
||||
"root - Down, <cgaf> - Leaf class=0 belief= 0.928297 impurity=0.37"
|
||||
"22 counts=(array([0, 1]), array([725, 56]))",
|
||||
"root - Up feaures=(0, 1, 2) impurity=0.2178 counts=(array([0, 1])"
|
||||
", array([ 25, 694]))",
|
||||
"root - Up - Down feaures=(0, 1, 2) impurity=0.8454 counts=(array("
|
||||
"[0, 1]), array([8, 3]))",
|
||||
"root - Up - Down - Down, <pure> - Leaf class=0 belief= 1.000000 i"
|
||||
"mpurity=0.0000 counts=(array([0]), array([7]))",
|
||||
"root - Up - Down - Up, <cgaf> - Leaf class=1 belief= 0.750000 imp"
|
||||
"urity=0.8113 counts=(array([0, 1]), array([1, 3]))",
|
||||
"root - Up - Up, <cgaf> - Leaf class=1 belief= 0.975989 impurity=0"
|
||||
".1634 counts=(array([0, 1]), array([ 17, 691]))",
|
||||
"root feaures=(0, 1, 2) impurity=1.0000 counts=(array([0, 1]), "
|
||||
"array([750, 750]))",
|
||||
"root - Down(2), <cgaf> - Leaf class=0 belief= 0.928297 impurity="
|
||||
"0.3722 counts=(array([0, 1]), array([725, 56]))",
|
||||
"root - Up(2) feaures=(0, 1, 2) impurity=0.2178 counts=(array([0, "
|
||||
"1]), array([ 25, 694]))",
|
||||
"root - Up(2) - Down(3) feaures=(0, 1, 2) impurity=0.8454 counts="
|
||||
"(array([0, 1]), array([8, 3]))",
|
||||
"root - Up(2) - Down(3) - Down(4), <pure> - Leaf class=0 belief= "
|
||||
"1.000000 impurity=0.0000 counts=(array([0]), array([7]))",
|
||||
"root - Up(2) - Down(3) - Up(4), <cgaf> - Leaf class=1 belief= "
|
||||
"0.750000 impurity=0.8113 counts=(array([0, 1]), array([1, 3]))",
|
||||
"root - Up(2) - Up(3), <cgaf> - Leaf class=1 belief= 0.975989 "
|
||||
"impurity=0.1634 counts=(array([0, 1]), array([ 17, 691]))",
|
||||
]
|
||||
computed = []
|
||||
expected_string = ""
|
||||
@@ -439,3 +439,18 @@ class Stree_test(unittest.TestCase):
|
||||
self.assertEqual(model1.score(X, y), 1)
|
||||
self.assertAlmostEqual(model2.score(X, y), 0.66666667)
|
||||
self.assertEqual(model2.score(X, y, w), 1)
|
||||
|
||||
def test_depth(self):
|
||||
X, y = load_dataset(
|
||||
random_state=self._random_state,
|
||||
n_classes=3,
|
||||
n_features=5,
|
||||
n_samples=1500,
|
||||
)
|
||||
clf = Stree(random_state=self._random_state)
|
||||
clf.fit(X, y)
|
||||
self.assertEqual(6, clf.depth_)
|
||||
X, y = load_wine(return_X_y=True)
|
||||
clf = Stree(random_state=self._random_state)
|
||||
clf.fit(X, y)
|
||||
self.assertEqual(7, clf.depth_)
|
||||
|
Reference in New Issue
Block a user