Grapher working

This commit is contained in:
2020-05-20 14:26:55 +02:00
parent c0ef71f139
commit 6e35628c85
6 changed files with 326 additions and 267 deletions

View File

@@ -144,15 +144,17 @@ class Stree_test(unittest.TestCase):
"""Check that element 28 has a prediction different that the current label
"""
# Element 28 has a different prediction than the truth
decimals = 8
X, y = self._get_Xy()
yp = self._clf.predict_proba(X[28, :].reshape(-1, X.shape[1]))
self.assertEqual(0, yp[0:, 0])
self.assertEqual(1, y[28])
self.assertEqual(0.29026400766, round(yp[0, 1], 11))
self.assertEqual(round(0.29026400766, decimals), round(yp[0, 1], decimals))
def test_multiple_predict_proba(self):
# First 27 elements the predictions are the same as the truth
num = 27
decimals = 8
X, y = self._get_Xy()
yp = self._clf.predict_proba(X[:num, :])
self.assertListEqual(y[:num].tolist(), yp[:, 0].tolist())
@@ -161,7 +163,9 @@ class Stree_test(unittest.TestCase):
0.30756427, 0.8318412, 0.18981198, 0.15564624, 0.25740655, 0.22923355,
0.87365959, 0.49928689, 0.95574351, 0.28761257, 0.28906333, 0.32643692,
0.29788483, 0.01657364, 0.81149083]
self.assertListEqual(expected_proba, np.round(yp[:, 1], decimals=8).tolist())
self.assertListEqual(
np.round(expected_proba, decimals=decimals).tolist(),
np.round(yp[:, 1], decimals=decimals).tolist())
def build_models(self):
"""Build and train two models, model_clf will use the sklearn classifier to