diff --git a/tests/Stree_test.py b/tests/Stree_test.py index e3736d5..32ebca6 100644 --- a/tests/Stree_test.py +++ b/tests/Stree_test.py @@ -144,17 +144,21 @@ class Stree_test(unittest.TestCase): """Check that element 28 has a prediction different that the current label """ # Element 28 has a different prediction than the truth - decimals = 8 + decimals = 5 X, y = self._get_Xy() yp = self._clf.predict_proba(X[28, :].reshape(-1, X.shape[1])) self.assertEqual(0, yp[0:, 0]) self.assertEqual(1, y[28]) - self.assertEqual(round(0.29026400766, decimals), round(yp[0, 1], decimals)) + self.assertAlmostEqual( + round(0.29026400766, decimals), + round(yp[0, 1], decimals), + decimals + ) def test_multiple_predict_proba(self): # First 27 elements the predictions are the same as the truth num = 27 - decimals = 8 + decimals = 5 X, y = self._get_Xy() yp = self._clf.predict_proba(X[:num, :]) self.assertListEqual(y[:num].tolist(), yp[:, 0].tolist()) @@ -163,9 +167,10 @@ class Stree_test(unittest.TestCase): 0.30756427, 0.8318412, 0.18981198, 0.15564624, 0.25740655, 0.22923355, 0.87365959, 0.49928689, 0.95574351, 0.28761257, 0.28906333, 0.32643692, 0.29788483, 0.01657364, 0.81149083] - self.assertListEqual( - np.round(expected_proba, decimals=decimals).tolist(), - np.round(yp[:, 1], decimals=decimals).tolist()) + expected = np.round(expected_proba, decimals=decimals).tolist() + computed = np.round(yp[:, 1], decimals=decimals).tolist() + for i in range(len(expected)): + self.assertAlmostEqual(expected[i], computed[i], decimals) def build_models(self): """Build and train two models, model_clf will use the sklearn classifier to