From 0f89b044f18c4ac4e91e4abd27b1a3f9df9e180e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Wed, 7 Apr 2021 01:02:30 +0200 Subject: [PATCH] Refactor train method --- stree/Strees.py | 52 ++++++++++++++++++++++++--------------- stree/tests/Snode_test.py | 25 +++++++++++++++++++ 2 files changed, 57 insertions(+), 20 deletions(-) diff --git a/stree/Strees.py b/stree/Strees.py index c36f190..fa0cbd9 100644 --- a/stree/Strees.py +++ b/stree/Strees.py @@ -79,6 +79,30 @@ class Snode: def set_down(self, son): self._down = son + def set_title(self, title): + self._title = title + + def set_classifier(self, clf): + self._clf = clf + + def set_features(self, features): + self._features = features + + def set_impurity(self, impurity): + self._impurity = impurity + + def get_title(self) -> str: + return self._title + + def get_classifier(self) -> SVC: + return self._clf + + def get_impurity(self) -> float: + return self._impurity + + def get_features(self) -> np.array: + return self._features + def set_up(self, son): self._up = son @@ -636,38 +660,26 @@ class Stree(BaseEstimator, ClassifierMixin): y = y[~indices_zero] sample_weight = sample_weight[~indices_zero] self.depth_ = max(depth, self.depth_) + node = Snode(None, X, y, X.shape[1], 0.0, title, sample_weight) if np.unique(y).shape[0] == 1: # only 1 class => pure dataset - return Snode( - clf=None, - X=X, - y=y, - features=X.shape[1], - impurity=0.0, - title=title + ", ", - weight=sample_weight, - ) + node.set_title(title + ", ") + return node # Train the model clf = self._build_clf() Xs, features = self.splitter_.get_subspace(X, y, self.max_features_) clf.fit(Xs, y, sample_weight=sample_weight) - impurity = self.splitter_.partition_impurity(y) - node = Snode(clf, X, y, features, impurity, title, sample_weight) + node.set_impurity(self.splitter_.partition_impurity(y)) + node.set_classifier(clf) + node.set_features(features) self.splitter_.partition(X, node, True) X_U, X_D = self.splitter_.part(X) y_u, y_d = self.splitter_.part(y) sw_u, sw_d = self.splitter_.part(sample_weight) if X_U is None or X_D is None: # didn't part anything - return Snode( - clf, - X, - y, - features=X.shape[1], - impurity=impurity, - title=title + ", ", - weight=sample_weight, - ) + node.set_title(title + ", ") + return node node.set_up( self.train(X_U, y_u, sw_u, depth + 1, title + f" - Up({depth+1})") ) diff --git a/stree/tests/Snode_test.py b/stree/tests/Snode_test.py index b32880a..b1e2728 100644 --- a/stree/tests/Snode_test.py +++ b/stree/tests/Snode_test.py @@ -69,6 +69,31 @@ class Snode_test(unittest.TestCase): self.assertEqual(0.75, test._belief) self.assertEqual(-1, test._partition_column) + def test_set_title(self): + test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test") + self.assertEqual("test", test.get_title()) + test.set_title("another") + self.assertEqual("another", test.get_title()) + + def test_set_classifier(self): + test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test") + clf = Stree() + self.assertIsNone(test.get_classifier()) + test.set_classifier(clf) + self.assertEqual(clf, test.get_classifier()) + + def test_set_impurity(self): + test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test") + self.assertEqual(0.0, test.get_impurity()) + test.set_impurity(54.7) + self.assertEqual(54.7, test.get_impurity()) + + def test_set_features(self): + test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [0, 1], 0.0, "test") + self.assertListEqual([0, 1], test.get_features()) + test.set_features([1, 2]) + self.assertListEqual([1, 2], test.get_features()) + def test_make_predictor_on_not_leaf(self): test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test") test.set_up(Snode(None, [1], [1], [], 0.0, "another_test"))