Refactor train method

This commit is contained in:
2021-04-07 01:02:30 +02:00
parent 6ba973dfe1
commit 0f89b044f1
2 changed files with 57 additions and 20 deletions

View File

@@ -79,6 +79,30 @@ class Snode:
def set_down(self, son):
self._down = son
def set_title(self, title):
self._title = title
def set_classifier(self, clf):
self._clf = clf
def set_features(self, features):
self._features = features
def set_impurity(self, impurity):
self._impurity = impurity
def get_title(self) -> str:
return self._title
def get_classifier(self) -> SVC:
return self._clf
def get_impurity(self) -> float:
return self._impurity
def get_features(self) -> np.array:
return self._features
def set_up(self, son):
self._up = son
@@ -636,38 +660,26 @@ class Stree(BaseEstimator, ClassifierMixin):
y = y[~indices_zero]
sample_weight = sample_weight[~indices_zero]
self.depth_ = max(depth, self.depth_)
node = Snode(None, X, y, X.shape[1], 0.0, title, sample_weight)
if np.unique(y).shape[0] == 1:
# only 1 class => pure dataset
return Snode(
clf=None,
X=X,
y=y,
features=X.shape[1],
impurity=0.0,
title=title + ", <pure>",
weight=sample_weight,
)
node.set_title(title + ", <pure>")
return node
# Train the model
clf = self._build_clf()
Xs, features = self.splitter_.get_subspace(X, y, self.max_features_)
clf.fit(Xs, y, sample_weight=sample_weight)
impurity = self.splitter_.partition_impurity(y)
node = Snode(clf, X, y, features, impurity, title, sample_weight)
node.set_impurity(self.splitter_.partition_impurity(y))
node.set_classifier(clf)
node.set_features(features)
self.splitter_.partition(X, node, True)
X_U, X_D = self.splitter_.part(X)
y_u, y_d = self.splitter_.part(y)
sw_u, sw_d = self.splitter_.part(sample_weight)
if X_U is None or X_D is None:
# didn't part anything
return Snode(
clf,
X,
y,
features=X.shape[1],
impurity=impurity,
title=title + ", <cgaf>",
weight=sample_weight,
)
node.set_title(title + ", <cgaf>")
return node
node.set_up(
self.train(X_U, y_u, sw_u, depth + 1, title + f" - Up({depth+1})")
)

View File

@@ -69,6 +69,31 @@ class Snode_test(unittest.TestCase):
self.assertEqual(0.75, test._belief)
self.assertEqual(-1, test._partition_column)
def test_set_title(self):
test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test")
self.assertEqual("test", test.get_title())
test.set_title("another")
self.assertEqual("another", test.get_title())
def test_set_classifier(self):
test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test")
clf = Stree()
self.assertIsNone(test.get_classifier())
test.set_classifier(clf)
self.assertEqual(clf, test.get_classifier())
def test_set_impurity(self):
test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test")
self.assertEqual(0.0, test.get_impurity())
test.set_impurity(54.7)
self.assertEqual(54.7, test.get_impurity())
def test_set_features(self):
test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [0, 1], 0.0, "test")
self.assertListEqual([0, 1], test.get_features())
test.set_features([1, 2])
self.assertListEqual([1, 2], test.get_features())
def test_make_predictor_on_not_leaf(self):
test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test")
test.set_up(Snode(None, [1], [1], [], 0.0, "another_test"))