From c593b55bec2d8c0ff6cb5dd57be6669321a06964 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?=
Date: Sat, 17 Oct 2020 16:56:15 +0200
Subject: [PATCH] Complete implementation of splitter_type = impurity with
 tests

Remove max_distance & min_distance splitter types
---
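Notes:

The hunks below make both split criteria (_impurity and _max_samples) return
the index of the column chosen to partition the node instead of the column of
distances itself, cache that index on each Snode through
set_partition_column()/get_partition_column(), and add a train flag to
Splitter.partition() so that predict() reuses the column selected at fit time.
A minimal usage sketch of the new default criterion follows; the dataset and
the hyper-parameter values are illustrative and are not taken from this patch:

    from sklearn.datasets import load_wine
    from stree import Stree

    X, y = load_wine(return_X_y=True)
    # "impurity" is now the default split_criteria, "max_samples" is still accepted
    clf = Stree(kernel="linear", split_criteria="impurity", random_state=0)
    clf.fit(X, y)
    print(clf.score(X, y))
    # every node now caches the distance column used to partition it
    # (-1 for leaves and for binary, single-column partitions)
    print(clf.tree_.get_partition_column())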
 stree/Strees.py              | 40 ++++++++++++++++++++++++++++------------
 stree/tests/Snode_test.py    | 16 ++++++++++++----
 stree/tests/Splitter_test.py | 26 +++++++++++++++---------
 stree/tests/Stree_test.py    | 20 ++++++++++----------
 4 files changed, 67 insertions(+), 35 deletions(-)

diff --git a/stree/Strees.py b/stree/Strees.py
index d296647..8338ba4 100644
--- a/stree/Strees.py
+++ b/stree/Strees.py
@@ -57,6 +57,7 @@ class Snode:
         )
         self._features = features
         self._impurity = impurity
+        self._partition_column: int = -1
 
     @classmethod
     def copy(cls, node: "Snode") -> "Snode":
@@ -69,6 +70,12 @@ class Snode:
             node._title,
         )
 
+    def set_partition_column(self, col: int):
+        self._partition_column = col
+
+    def get_partition_column(self) -> int:
+        return self._partition_column
+
     def set_down(self, son):
         self._down = son
 
@@ -239,7 +246,7 @@ class Splitter:
             node = Snode(
                 self._clf, dataset, labels, feature_set, 0.0, "subset"
             )
-            self.partition(dataset, node)
+            self.partition(dataset, node, train=True)
             y1, y2 = self.part(labels)
             gain = self.information_gain(labels, y1, y2)
             if gain > max_gain:
@@ -272,7 +279,7 @@ class Splitter:
         return dataset[:, indices], indices
 
     def _impurity(self, data: np.array, _) -> np.array:
-        """return distances of the class whose partition has less impurity
+        """return the column of the dataset to be used to split it
 
         :param data: distances to hyper plane of every class
         :type data: np.array (m, n_classes)
@@ -289,15 +296,15 @@ class Splitter:
         y[data > 0] = 1
         y = y.astype(int)
         for col in range(data.shape[1]):
-            impurity_of_class = self.partition_impurity(y[col])
+            impurity_of_class = self.partition_impurity(y[:, col])
             if impurity_of_class < min_impurity:
                 selected = col
                 min_impurity = impurity_of_class
-        return data[:, selected]
+        return selected
 
     @staticmethod
     def _max_samples(data: np.array, y: np.array) -> np.array:
-        """return distances of the class with more samples
+        """return the column of the dataset to be used to split it
 
         :param data: distances to hyper plane of every class
         :type data: np.array (m, n_classes)
@@ -308,10 +315,9 @@ class Splitter:
         """
         # select the class with max number of samples
         _, samples = np.unique(y, return_counts=True)
-        selected = np.argmax(samples)
-        return data[:, selected]
+        return np.argmax(samples)
 
-    def partition(self, samples: np.array, node: Snode):
+    def partition(self, samples: np.array, node: Snode, train: bool):
         """Set the criteria to split arrays. Compute the indices of the
         samples that should go to one side of the tree (down)
 
@@ -325,7 +331,16 @@ class Splitter:
         if data.ndim > 1:
             # split criteria for multiclass
             # Convert data to a (m, 1) array selecting values for samples
-            data = self.decision_criteria(data, node._y)
+            if train:
+                # in train time we have to compute the column to take into
+                # account to split the dataset
+                col = self.decision_criteria(data, node._y)
+                node.set_partition_column(col)
+            else:
+                # in predict time just use the column computed in train time
+                # that is stored in the node
+                col = node.get_partition_column()
+            data = data[:, col]
         self._down = data > 0
 
     @staticmethod
@@ -344,6 +359,7 @@ class Splitter:
 
     def part(self, origin: np.array) -> list:
         """Split an array in two based on indices (down) and its complement
+        partition has to be called first to establish down indices
 
         :param origin: dataset to split
         :type origin: np.array
@@ -377,7 +393,7 @@ class Stree(BaseEstimator, ClassifierMixin):
         tol: float = 1e-4,
         degree: int = 3,
         gamma="scale",
-        split_criteria: str = "max_samples",
+        split_criteria: str = "impurity",
         criterion: str = "gini",
         min_samples_split: int = 0,
         max_features=None,
@@ -518,7 +534,7 @@ class Stree(BaseEstimator, ClassifierMixin):
         impurity = self.splitter_.partition_impurity(y)
         node = Snode(clf, X, y, features, impurity, title, sample_weight)
         self.depth_ = max(depth, self.depth_)
-        self.splitter_.partition(X, node)
+        self.splitter_.partition(X, node, True)
         X_U, X_D = self.splitter_.part(X)
         y_u, y_d = self.splitter_.part(y)
         sw_u, sw_d = self.splitter_.part(sample_weight)
@@ -605,7 +621,7 @@ class Stree(BaseEstimator, ClassifierMixin):
                 # set a class for every sample in dataset
                 prediction = np.full((xp.shape[0], 1), node._class)
                 return prediction, indices
-            self.splitter_.partition(xp, node)
+            self.splitter_.partition(xp, node, train=False)
             x_u, x_d = self.splitter_.part(xp)
             i_u, i_d = self.splitter_.part(indices)
             prx_u, prin_u = predict_class(x_u, i_u, node.get_up())
diff --git a/stree/tests/Snode_test.py b/stree/tests/Snode_test.py
index 27e5d0a..b32880a 100644
--- a/stree/tests/Snode_test.py
+++ b/stree/tests/Snode_test.py
@@ -40,12 +40,13 @@ class Snode_test(unittest.TestCase):
             # Check Class
             class_computed = classes[card == max_card]
             self.assertEqual(class_computed, node._class)
+            # Check Partition column
+            self.assertEqual(node._partition_column, -1)
 
         check_leave(self._clf.tree_)
 
     def test_nodes_coefs(self):
-        """Check if the nodes of the tree have the right attributes filled
-        """
+        """Check if the nodes of the tree have the right attributes filled"""
 
         def run_tree(node: Snode):
             if node._belief < 1:
@@ -54,16 +55,19 @@ class Snode_test(unittest.TestCase):
                 self.assertIsNotNone(node._clf.coef_)
             if node.is_leaf():
                 return
-            run_tree(node.get_down())
             run_tree(node.get_up())
+            run_tree(node.get_down())
 
-        run_tree(self._clf.tree_)
+        model = Stree(self._random_state)
+        model.fit(*load_dataset(self._random_state, 3, 4))
+        run_tree(model.tree_)
 
     def test_make_predictor_on_leaf(self):
         test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test")
         test.make_predictor()
         self.assertEqual(1, test._class)
         self.assertEqual(0.75, test._belief)
+        self.assertEqual(-1, test._partition_column)
 
     def test_make_predictor_on_not_leaf(self):
         test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test")
@@ -71,11 +75,14 @@ class Snode_test(unittest.TestCase):
         test.make_predictor()
         self.assertIsNone(test._class)
         self.assertEqual(0, test._belief)
+        self.assertEqual(-1, test._partition_column)
+        self.assertEqual(-1, test.get_up()._partition_column)
 
     def test_make_predictor_on_leaf_bogus_data(self):
         test = Snode(None, [1, 2, 3, 4], [], [], 0.0, "test")
         test.make_predictor()
         self.assertIsNone(test._class)
+        self.assertEqual(-1, test._partition_column)
 
     def test_copy_node(self):
         px = [1, 2, 3, 4]
@@ -86,3 +93,4 @@ class Snode_test(unittest.TestCase):
         self.assertListEqual(computed._y, py)
         self.assertEqual("test", computed._title)
         self.assertIsInstance(computed._clf, Stree)
+        self.assertEqual(test._partition_column, computed._partition_column)
diff --git a/stree/tests/Splitter_test.py b/stree/tests/Splitter_test.py
index 4e55bc6..a4eb3e3 100644
--- a/stree/tests/Splitter_test.py
+++ b/stree/tests/Splitter_test.py
@@ -134,13 +134,17 @@ class Splitter_test(unittest.TestCase):
                 [0.7, 0.01, -0.1],
                 [0.7, -0.9, 0.5],
                 [0.1, 0.2, 0.3],
+                [-0.1, 0.2, 0.3],
+                [-0.1, 0.2, 0.3],
             ]
         )
-        expected = np.array([0.2, 0.01, -0.9, 0.2])
-        y = [1, 2, 1, 0]
+        expected = np.array([-0.1, 0.7, 0.7, 0.1, -0.1, -0.1])
+        y = [1, 2, 1, 0, 0, 0]
         computed = tcl._max_samples(data, y)
-        self.assertEqual((4,), computed.shape)
-        self.assertListEqual(expected.tolist(), computed.tolist())
+        self.assertEqual(0, computed)
+        computed_data = data[:, computed]
+        self.assertEqual((6,), computed_data.shape)
+        self.assertListEqual(expected.tolist(), computed_data.tolist())
 
     def test_impurity(self):
         tcl = self.build(criteria="impurity")
@@ -150,12 +154,16 @@ class Splitter_test(unittest.TestCase):
                 [0.7, 0.01, -0.1],
                 [0.7, -0.9, 0.5],
                 [0.1, 0.2, 0.3],
+                [-0.1, 0.2, 0.3],
+                [-0.1, 0.2, 0.3],
             ]
         )
-        expected = np.array([-0.1, 0.7, 0.7, 0.1])
+        expected = np.array([0.2, 0.01, -0.9, 0.2, 0.2, 0.2])
         computed = tcl._impurity(data, None)
-        self.assertEqual((4,), computed.shape)
-        self.assertListEqual(expected.tolist(), computed.tolist())
+        self.assertEqual(1, computed)
+        computed_data = data[:, computed]
+        self.assertEqual((6,), computed_data.shape)
+        self.assertListEqual(expected.tolist(), computed_data.tolist())
 
     def test_best_splitter_few_sets(self):
         X, y = load_iris(return_X_y=True)
@@ -168,9 +176,9 @@ class Splitter_test(unittest.TestCase):
     def test_splitter_parameter(self):
         expected_values = [
             [0, 1, 7, 9],  # best entropy max_samples
-            [3, 8, 10, 11],  # best entropy impurity
+            [0, 2, 4, 5],  # best entropy impurity
             [0, 2, 8, 12],  # best gini max_samples
-            [1, 2, 5, 12],  # best gini impurity
+            [4, 5, 9, 12],  # best gini impurity
             [1, 2, 5, 10],  # random entropy max_samples
             [4, 8, 9, 12],  # random entropy impurity
             [3, 9, 11, 12],  # random gini max_samples
diff --git a/stree/tests/Stree_test.py b/stree/tests/Stree_test.py
index c10f0f9..73ba74e 100644
--- a/stree/tests/Stree_test.py
+++ b/stree/tests/Stree_test.py
@@ -217,7 +217,7 @@ class Stree_test(unittest.TestCase):
                     random_state=self._random_state,
                 )
                 clf.fit(px, py)
-                print(f"{name} {criteria} {kernel}")
+                # print(f"{name} {criteria} {kernel}")
                 outcome = outcomes[name][f"{criteria} {kernel}"]
                 self.assertAlmostEqual(outcome, clf.score(px, py))
 
@@ -310,23 +310,23 @@ class Stree_test(unittest.TestCase):
     def test_score_multi_class(self):
         warnings.filterwarnings("ignore")
         accuracies = [
-            0.651685393258427,  # Wine linear impurity
+            0.7022472,  # Wine linear impurity
             0.8314607,  # Wine linear max_samples
-            0.6629213483146067,  # Wine rbf impurity
+            0.4044944,  # Wine rbf impurity
             0.4044944,  # Wine rbf max_samples
-            0.9157303,  # Wine poly impurity
+            0.3988764,  # Wine poly impurity
             0.7640449,  # Wine poly max_samples
-            0.9933333,  # Iris linear impurity
+            0.6600000,  # Iris linear impurity
             0.9666667,  # Iris linear max_samples
-            0.9800000,  # Iris rbf impurity
+            0.3333333,  # Iris rbf impurity
             0.9800000,  # Iris rbf max_samples
-            1.0000000,  # Iris poly impurity
+            0.3333333,  # Iris poly impurity
             1.0000000,  # Iris poly max_samples
-            0.8993333,  # Synthetic linear impurity
+            0.7153333,  # Synthetic linear impurity
             0.9313333,  # Synthetic linear max_samples
-            0.8320000,  # Synthetic rbf impurity
+            0.4806667,  # Synthetic rbf impurity
             0.8320000,  # Synthetic rbf max_samples
-            0.6066667,  # Synthetic poly impurity
+            0.4786667,  # Synthetic poly impurity
             0.6340000,  # Synthetic poly max_samples
         ]
         datasets = [