mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-17 16:36:01 +00:00
Add a test to the tests set Add depth to node description Fix iterator and str test due to this addon
This commit is contained in:
committed by
GitHub
parent
f438124057
commit
460c63a6d0
@@ -635,6 +635,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
X = X[~indices_zero, :]
|
X = X[~indices_zero, :]
|
||||||
y = y[~indices_zero]
|
y = y[~indices_zero]
|
||||||
sample_weight = sample_weight[~indices_zero]
|
sample_weight = sample_weight[~indices_zero]
|
||||||
|
self.depth_ = max(depth, self.depth_)
|
||||||
if np.unique(y).shape[0] == 1:
|
if np.unique(y).shape[0] == 1:
|
||||||
# only 1 class => pure dataset
|
# only 1 class => pure dataset
|
||||||
return Snode(
|
return Snode(
|
||||||
@@ -652,7 +653,6 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
clf.fit(Xs, y, sample_weight=sample_weight)
|
clf.fit(Xs, y, sample_weight=sample_weight)
|
||||||
impurity = self.splitter_.partition_impurity(y)
|
impurity = self.splitter_.partition_impurity(y)
|
||||||
node = Snode(clf, X, y, features, impurity, title, sample_weight)
|
node = Snode(clf, X, y, features, impurity, title, sample_weight)
|
||||||
self.depth_ = max(depth, self.depth_)
|
|
||||||
self.splitter_.partition(X, node, True)
|
self.splitter_.partition(X, node, True)
|
||||||
X_U, X_D = self.splitter_.part(X)
|
X_U, X_D = self.splitter_.part(X)
|
||||||
y_u, y_d = self.splitter_.part(y)
|
y_u, y_d = self.splitter_.part(y)
|
||||||
@@ -668,8 +668,14 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
title=title + ", <cgaf>",
|
title=title + ", <cgaf>",
|
||||||
weight=sample_weight,
|
weight=sample_weight,
|
||||||
)
|
)
|
||||||
node.set_up(self.train(X_U, y_u, sw_u, depth + 1, title + " - Up"))
|
node.set_up(
|
||||||
node.set_down(self.train(X_D, y_d, sw_d, depth + 1, title + " - Down"))
|
self.train(X_U, y_u, sw_u, depth + 1, title + f" - Up({depth+1})")
|
||||||
|
)
|
||||||
|
node.set_down(
|
||||||
|
self.train(
|
||||||
|
X_D, y_d, sw_d, depth + 1, title + f" - Down({depth+1})"
|
||||||
|
)
|
||||||
|
)
|
||||||
return node
|
return node
|
||||||
|
|
||||||
def _build_predictor(self):
|
def _build_predictor(self):
|
||||||
|
@@ -103,20 +103,20 @@ class Stree_test(unittest.TestCase):
|
|||||||
def test_iterator_and_str(self):
|
def test_iterator_and_str(self):
|
||||||
"""Check preorder iterator"""
|
"""Check preorder iterator"""
|
||||||
expected = [
|
expected = [
|
||||||
"root feaures=(0, 1, 2) impurity=1.0000 counts=(array([0, 1]), arr"
|
"root feaures=(0, 1, 2) impurity=1.0000 counts=(array([0, 1]), "
|
||||||
"ay([750, 750]))",
|
"array([750, 750]))",
|
||||||
"root - Down, <cgaf> - Leaf class=0 belief= 0.928297 impurity=0.37"
|
"root - Down(2), <cgaf> - Leaf class=0 belief= 0.928297 impurity="
|
||||||
"22 counts=(array([0, 1]), array([725, 56]))",
|
"0.3722 counts=(array([0, 1]), array([725, 56]))",
|
||||||
"root - Up feaures=(0, 1, 2) impurity=0.2178 counts=(array([0, 1])"
|
"root - Up(2) feaures=(0, 1, 2) impurity=0.2178 counts=(array([0, "
|
||||||
", array([ 25, 694]))",
|
"1]), array([ 25, 694]))",
|
||||||
"root - Up - Down feaures=(0, 1, 2) impurity=0.8454 counts=(array("
|
"root - Up(2) - Down(3) feaures=(0, 1, 2) impurity=0.8454 counts="
|
||||||
"[0, 1]), array([8, 3]))",
|
"(array([0, 1]), array([8, 3]))",
|
||||||
"root - Up - Down - Down, <pure> - Leaf class=0 belief= 1.000000 i"
|
"root - Up(2) - Down(3) - Down(4), <pure> - Leaf class=0 belief= "
|
||||||
"mpurity=0.0000 counts=(array([0]), array([7]))",
|
"1.000000 impurity=0.0000 counts=(array([0]), array([7]))",
|
||||||
"root - Up - Down - Up, <cgaf> - Leaf class=1 belief= 0.750000 imp"
|
"root - Up(2) - Down(3) - Up(4), <cgaf> - Leaf class=1 belief= "
|
||||||
"urity=0.8113 counts=(array([0, 1]), array([1, 3]))",
|
"0.750000 impurity=0.8113 counts=(array([0, 1]), array([1, 3]))",
|
||||||
"root - Up - Up, <cgaf> - Leaf class=1 belief= 0.975989 impurity=0"
|
"root - Up(2) - Up(3), <cgaf> - Leaf class=1 belief= 0.975989 "
|
||||||
".1634 counts=(array([0, 1]), array([ 17, 691]))",
|
"impurity=0.1634 counts=(array([0, 1]), array([ 17, 691]))",
|
||||||
]
|
]
|
||||||
computed = []
|
computed = []
|
||||||
expected_string = ""
|
expected_string = ""
|
||||||
@@ -439,3 +439,18 @@ class Stree_test(unittest.TestCase):
|
|||||||
self.assertEqual(model1.score(X, y), 1)
|
self.assertEqual(model1.score(X, y), 1)
|
||||||
self.assertAlmostEqual(model2.score(X, y), 0.66666667)
|
self.assertAlmostEqual(model2.score(X, y), 0.66666667)
|
||||||
self.assertEqual(model2.score(X, y, w), 1)
|
self.assertEqual(model2.score(X, y, w), 1)
|
||||||
|
|
||||||
|
def test_depth(self):
|
||||||
|
X, y = load_dataset(
|
||||||
|
random_state=self._random_state,
|
||||||
|
n_classes=3,
|
||||||
|
n_features=5,
|
||||||
|
n_samples=1500,
|
||||||
|
)
|
||||||
|
clf = Stree(random_state=self._random_state)
|
||||||
|
clf.fit(X, y)
|
||||||
|
self.assertEqual(6, clf.depth_)
|
||||||
|
X, y = load_wine(return_X_y=True)
|
||||||
|
clf = Stree(random_state=self._random_state)
|
||||||
|
clf.fit(X, y)
|
||||||
|
self.assertEqual(7, clf.depth_)
|
||||||
|
Reference in New Issue
Block a user