mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-16 07:56:06 +00:00
Implement predict and score methods & tests
This commit is contained in:
@@ -11,9 +11,9 @@ class Snode_test(unittest.TestCase):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._random_state = 1
|
||||
self._model = Stree(random_state=self._random_state,
|
||||
self._clf = Stree(random_state=self._random_state,
|
||||
use_predictions=True)
|
||||
self._model.fit(*self._get_Xy())
|
||||
self._clf.fit(*self._get_Xy())
|
||||
super(Snode_test, self).__init__(*args, **kwargs)
|
||||
|
||||
def _get_Xy(self):
|
||||
@@ -42,4 +42,4 @@ class Snode_test(unittest.TestCase):
|
||||
return
|
||||
check_leave(node.get_down())
|
||||
check_leave(node.get_up())
|
||||
check_leave(self._model._tree)
|
||||
check_leave(self._clf._tree)
|
||||
|
@@ -11,9 +11,9 @@ class Stree_test(unittest.TestCase):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._random_state = 1
|
||||
self._model = Stree(random_state=self._random_state,
|
||||
use_predictions=True)
|
||||
self._model.fit(*self._get_Xy())
|
||||
self._clf = Stree(random_state=self._random_state,
|
||||
use_predictions=False)
|
||||
self._clf.fit(*self._get_Xy())
|
||||
super(Stree_test, self).__init__(*args, **kwargs)
|
||||
|
||||
def _get_Xy(self):
|
||||
@@ -25,7 +25,7 @@ class Stree_test(unittest.TestCase):
|
||||
def _check_tree(self, node: Snode):
|
||||
if node.is_leaf():
|
||||
return
|
||||
y_prediction = node._model.predict(node._X)
|
||||
y_prediction = node._clf.predict(node._X)
|
||||
y_down = node.get_down()._y
|
||||
y_up = node.get_up()._y
|
||||
# Is a correct partition in terms of cadinality?
|
||||
@@ -55,7 +55,7 @@ class Stree_test(unittest.TestCase):
|
||||
def test_build_tree(self):
|
||||
"""Check if the tree is built the same way as predictions of models
|
||||
"""
|
||||
self._check_tree(self._model._tree)
|
||||
self._check_tree(self._clf._tree)
|
||||
|
||||
def _get_file_data(self, file_name: str) -> tuple:
|
||||
"""Return X, y from data, y is the last column in array
|
||||
@@ -94,14 +94,32 @@ class Stree_test(unittest.TestCase):
|
||||
def test_subdatasets(self):
|
||||
"""Check if the subdatasets files have the same predictions as the tree itself
|
||||
"""
|
||||
model = self._model._tree._model
|
||||
model = self._clf._tree._clf
|
||||
X, y = self._get_Xy()
|
||||
model.fit(X, y)
|
||||
self._model.save_sub_datasets()
|
||||
with open(self._model.get_catalog_name()) as cat_file:
|
||||
self._clf.save_sub_datasets()
|
||||
with open(self._clf.get_catalog_name()) as cat_file:
|
||||
catalog = csv.reader(cat_file, delimiter=',')
|
||||
for row in catalog:
|
||||
X, y = self._get_Xy()
|
||||
x_file, y_file = self._get_file_data(row[0])
|
||||
y_original = np.array(self._find_out(x_file, X, y), dtype=int)
|
||||
self.assertTrue(np.array_equal(y_file, y_original))
|
||||
|
||||
def test_single_prediction(self):
|
||||
X, y = self._get_Xy()
|
||||
yp = self._clf.predict((X[0, :].reshape(-1, X.shape[1])))
|
||||
self.assertEqual(yp[0], y[0])
|
||||
|
||||
def test_multiple_prediction(self):
|
||||
X, y = self._get_Xy()
|
||||
yp = self._clf.predict(X[:23, :])
|
||||
self.assertListEqual(y[:23].tolist(), yp.tolist())
|
||||
|
||||
def test_score(self):
|
||||
X, y = self._get_Xy()
|
||||
accuracy_score = self._clf.score(X, y, print_out=False)
|
||||
yp = self._clf.predict(X)
|
||||
right = (yp == y).astype(int)
|
||||
accuracy_computed = sum(right) / len(y)
|
||||
self.assertEqual(accuracy_score, accuracy_computed)
|
||||
|
Reference in New Issue
Block a user