mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-18 00:46:02 +00:00
Add predict_proba test
This commit is contained in:
@@ -115,6 +115,38 @@ class Stree_test(unittest.TestCase):
|
|||||||
yp = clf.fit(X, y).predict(X[:num, :])
|
yp = clf.fit(X, y).predict(X[:num, :])
|
||||||
self.assertListEqual(y[:num].tolist(), yp.tolist())
|
self.assertListEqual(y[:num].tolist(), yp.tolist())
|
||||||
|
|
||||||
|
def test_multiple_predict_proba(self):
|
||||||
|
expected = {
|
||||||
|
"liblinear": {
|
||||||
|
0: [0.02401129943502825, 0.9759887005649718],
|
||||||
|
17: [0.9282970550576184, 0.07170294494238157],
|
||||||
|
},
|
||||||
|
"linear": {
|
||||||
|
0: [0.029329608938547486, 0.9706703910614525],
|
||||||
|
17: [0.9298469387755102, 0.07015306122448979],
|
||||||
|
},
|
||||||
|
"rbf": {
|
||||||
|
0: [0.023448275862068966, 0.976551724137931],
|
||||||
|
17: [0.9458064516129032, 0.05419354838709677],
|
||||||
|
},
|
||||||
|
"poly": {
|
||||||
|
0: [0.01601164483260553, 0.9839883551673945],
|
||||||
|
17: [0.9089790897908979, 0.0910209102091021],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
indices = [0, 17]
|
||||||
|
X, y = load_dataset(self._random_state)
|
||||||
|
for kernel in ["liblinear", "linear", "rbf", "poly"]:
|
||||||
|
clf = Stree(
|
||||||
|
kernel=kernel,
|
||||||
|
multiclass_strategy="ovr" if kernel == "liblinear" else "ovo",
|
||||||
|
random_state=self._random_state,
|
||||||
|
)
|
||||||
|
yp = clf.fit(X, y).predict_proba(X)
|
||||||
|
for index in indices:
|
||||||
|
for exp, comp in zip(expected[kernel][index], yp[index]):
|
||||||
|
self.assertAlmostEqual(exp, comp)
|
||||||
|
|
||||||
def test_single_vs_multiple_prediction(self):
|
def test_single_vs_multiple_prediction(self):
|
||||||
"""Check if predicting sample by sample gives the same result as
|
"""Check if predicting sample by sample gives the same result as
|
||||||
predicting all samples at once
|
predicting all samples at once
|
||||||
|
Reference in New Issue
Block a user