Add complete classes counts to node and tests

This commit is contained in:
2022-05-31 01:21:03 +02:00
parent 93be8a89a8
commit 65923af9b4
4 changed files with 32 additions and 29 deletions

View File

@@ -68,6 +68,7 @@ class Snode:
self._impurity = impurity
self._partition_column: int = -1
self._scaler = scaler
self._proba = None
@classmethod
def copy(cls, node: "Snode") -> "Snode":
@@ -127,22 +128,21 @@ class Snode:
def get_up(self) -> "Snode":
return self._up
def make_predictor(self):
def make_predictor(self, num_classes: int) -> None:
"""Compute the class of the predictor and its belief based on the
subdataset of the node only if it is a leaf
"""
if not self.is_leaf():
return
classes, card = np.unique(self._y, return_counts=True)
if len(classes) > 1:
self._proba = np.zeros((num_classes,))
for c, n in zip(classes, card):
self._proba[c] = n
try:
max_card = max(card)
self._class = classes[card == max_card][0]
self._belief = max_card / np.sum(card)
else:
self._belief = 1
try:
self._class = classes[0]
except IndexError:
except ValueError:
self._class = None
def graph(self):
@@ -155,7 +155,7 @@ class Snode:
output += (
f'N{id(self)} [shape=box style=filled label="'
f"class={self._class} impurity={self._impurity:.3f} "
f'classes={count_values[0]} samples={count_values[1]}"];\n'
f'counts={self._proba}"];\n'
)
else:
output += (

View File

@@ -314,7 +314,7 @@ class Stree(BaseEstimator, ClassifierMixin):
if np.unique(y).shape[0] == 1:
# only 1 class => pure dataset
node.set_title(title + ", <pure>")
node.make_predictor()
node.make_predictor(self.n_classes_)
return node
# Train the model
clf = self._build_clf()
@@ -333,7 +333,7 @@ class Stree(BaseEstimator, ClassifierMixin):
if X_U is None or X_D is None:
# didn't part anything
node.set_title(title + ", <cgaf>")
node.make_predictor()
node.make_predictor(self.n_classes_)
return node
node.set_up(
self._train(X_U, y_u, sw_u, depth + 1, title + f" - Up({depth+1})")

View File

@@ -67,10 +67,28 @@ class Snode_test(unittest.TestCase):
def test_make_predictor_on_leaf(self):
test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test")
test.make_predictor()
test.make_predictor(2)
self.assertEqual(1, test._class)
self.assertEqual(0.75, test._belief)
self.assertEqual(-1, test._partition_column)
self.assertListEqual([1, 3], test._proba.tolist())
def test_make_predictor_on_not_leaf(self):
test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test")
test.set_up(Snode(None, [1], [1], [], 0.0, "another_test"))
test.make_predictor(2)
self.assertIsNone(test._class)
self.assertEqual(0, test._belief)
self.assertEqual(-1, test._partition_column)
self.assertEqual(-1, test.get_up()._partition_column)
self.assertIsNone(test._proba)
def test_make_predictor_on_leaf_bogus_data(self):
test = Snode(None, [1, 2, 3, 4], [], [], 0.0, "test")
test.make_predictor(2)
self.assertIsNone(test._class)
self.assertEqual(-1, test._partition_column)
self.assertListEqual([0, 0], test._proba.tolist())
def test_set_title(self):
test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test")
@@ -97,21 +115,6 @@ class Snode_test(unittest.TestCase):
test.set_features([1, 2])
self.assertListEqual([1, 2], test.get_features())
def test_make_predictor_on_not_leaf(self):
test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test")
test.set_up(Snode(None, [1], [1], [], 0.0, "another_test"))
test.make_predictor()
self.assertIsNone(test._class)
self.assertEqual(0, test._belief)
self.assertEqual(-1, test._partition_column)
self.assertEqual(-1, test.get_up()._partition_column)
def test_make_predictor_on_leaf_bogus_data(self):
test = Snode(None, [1, 2, 3, 4], [], [], 0.0, "test")
test.make_predictor()
self.assertIsNone(test._class)
self.assertEqual(-1, test._partition_column)
def test_copy_node(self):
px = [1, 2, 3, 4]
py = [1]

View File

@@ -695,7 +695,7 @@ class Stree_test(unittest.TestCase):
)
expected_tail = (
' [shape=box style=filled label="class=1 impurity=0.000 '
'classes=[1] samples=[1]"];\n}\n'
'counts=[0. 1. 0.]"];\n}\n'
)
self.assertEqual(clf.graph(), expected_head + "}\n")
clf.fit(X, y)
@@ -715,7 +715,7 @@ class Stree_test(unittest.TestCase):
)
expected_tail = (
' [shape=box style=filled label="class=1 impurity=0.000 '
'classes=[1] samples=[1]"];\n}\n'
'counts=[0. 1. 0.]"];\n}\n'
)
self.assertEqual(clf.graph("Sample title"), expected_head + "}\n")
clf.fit(X, y)