From 9e8d03d08844dd6d442385bd66019abbb72d63d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Tue, 31 May 2022 23:46:12 +0200 Subject: [PATCH] Add predict_proba test --- stree/tests/Stree_test.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/stree/tests/Stree_test.py b/stree/tests/Stree_test.py index 3a4247a..0f13eaa 100644 --- a/stree/tests/Stree_test.py +++ b/stree/tests/Stree_test.py @@ -115,6 +115,38 @@ class Stree_test(unittest.TestCase): yp = clf.fit(X, y).predict(X[:num, :]) self.assertListEqual(y[:num].tolist(), yp.tolist()) + def test_multiple_predict_proba(self): + expected = { + "liblinear": { + 0: [0.02401129943502825, 0.9759887005649718], + 17: [0.9282970550576184, 0.07170294494238157], + }, + "linear": { + 0: [0.029329608938547486, 0.9706703910614525], + 17: [0.9298469387755102, 0.07015306122448979], + }, + "rbf": { + 0: [0.023448275862068966, 0.976551724137931], + 17: [0.9458064516129032, 0.05419354838709677], + }, + "poly": { + 0: [0.01601164483260553, 0.9839883551673945], + 17: [0.9089790897908979, 0.0910209102091021], + }, + } + indices = [0, 17] + X, y = load_dataset(self._random_state) + for kernel in ["liblinear", "linear", "rbf", "poly"]: + clf = Stree( + kernel=kernel, + multiclass_strategy="ovr" if kernel == "liblinear" else "ovo", + random_state=self._random_state, + ) + yp = clf.fit(X, y).predict_proba(X) + for index in indices: + for exp, comp in zip(expected[kernel][index], yp[index]): + self.assertAlmostEqual(exp, comp) + def test_single_vs_multiple_prediction(self): """Check if predicting sample by sample gives the same result as predicting all samples at once