mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-16 07:56:06 +00:00
Implement predict_proba with test.
Fix tree overload with dataset in nodes only needed in tests
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import unittest
|
||||
|
||||
from sklearn.datasets import make_classification
|
||||
import os
|
||||
import numpy as np
|
||||
import csv
|
||||
|
||||
@@ -10,12 +11,20 @@ from trees.Stree import Stree, Snode
|
||||
class Snode_test(unittest.TestCase):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
os.environ['TESTING'] = '1'
|
||||
self._random_state = 1
|
||||
self._clf = Stree(random_state=self._random_state,
|
||||
use_predictions=True)
|
||||
self._clf.fit(*self._get_Xy())
|
||||
super(Snode_test, self).__init__(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
try:
|
||||
os.environ.pop('TESTING')
|
||||
except:
|
||||
pass
|
||||
|
||||
def _get_Xy(self):
|
||||
X, y = make_classification(n_samples=1500, n_features=3, n_informative=3,
|
||||
n_redundant=0, n_repeated=0, n_classes=2, n_clusters_per_class=2,
|
||||
|
@@ -1,6 +1,7 @@
|
||||
import unittest
|
||||
|
||||
from sklearn.datasets import make_classification
|
||||
import os
|
||||
import numpy as np
|
||||
import csv
|
||||
|
||||
@@ -10,12 +11,20 @@ from trees.Stree import Stree, Snode
|
||||
class Stree_test(unittest.TestCase):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
os.environ['TESTING'] = '1'
|
||||
self._random_state = 1
|
||||
self._clf = Stree(random_state=self._random_state,
|
||||
use_predictions=False)
|
||||
self._clf.fit(*self._get_Xy())
|
||||
super(Stree_test, self).__init__(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
try:
|
||||
os.environ.pop('TESTING')
|
||||
except:
|
||||
pass
|
||||
|
||||
def _get_Xy(self):
|
||||
X, y = make_classification(n_samples=1500, n_features=3, n_informative=3,
|
||||
n_redundant=0, n_repeated=0, n_classes=2, n_clusters_per_class=2,
|
||||
@@ -112,9 +121,11 @@ class Stree_test(unittest.TestCase):
|
||||
self.assertEqual(yp[0], y[0])
|
||||
|
||||
def test_multiple_prediction(self):
|
||||
# First 27 elements the predictions are the same as the truth
|
||||
num = 27
|
||||
X, y = self._get_Xy()
|
||||
yp = self._clf.predict(X[:23, :])
|
||||
self.assertListEqual(y[:23].tolist(), yp.tolist())
|
||||
yp = self._clf.predict(X[:num, :])
|
||||
self.assertListEqual(y[:num].tolist(), yp.tolist())
|
||||
|
||||
def test_score(self):
|
||||
X, y = self._get_Xy()
|
||||
@@ -123,3 +134,26 @@ class Stree_test(unittest.TestCase):
|
||||
right = (yp == y).astype(int)
|
||||
accuracy_computed = sum(right) / len(y)
|
||||
self.assertEqual(accuracy_score, accuracy_computed)
|
||||
|
||||
def test_single_predict_proba(self):
|
||||
# Element 28 has a different prediction than the truth
|
||||
X, y = self._get_Xy()
|
||||
yp = self._clf.predict_proba(X[28, :].reshape(-1, X.shape[1]))
|
||||
self.assertEqual(0, yp[0:, 0])
|
||||
self.assertEqual(0.9282970550576184, yp[0:, 1])
|
||||
|
||||
def test_multiple_predict_proba(self):
|
||||
# First 27 elements the predictions are the same as the truth
|
||||
num = 27
|
||||
X, y = self._get_Xy()
|
||||
yp = self._clf.predict_proba(X[:num, :])
|
||||
self.assertListEqual(y[:num].tolist(), yp[:, 0].tolist())
|
||||
expected_proba = [0.9759887, 0.92829706, 0.9759887, 0.92829706, 0.92829706, 0.9759887,
|
||||
0.92829706, 0.9759887, 0.9759887, 0.9759887, 0.9759887, 0.92829706,
|
||||
0.92829706, 0.9759887, 0.92829706, 0.92829706, 0.92829706, 0.92829706,
|
||||
0.9759887, 0.92829706, 0.9759887, 0.92829706, 0.92829706, 0.92829706,
|
||||
0.92829706, 0.92829706, 0.9759887 ]
|
||||
self.assertListEqual(expected_proba, np.round(yp[:, 1], decimals=8).tolist())
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user