diff --git a/odte/tests/Odte_tests.py b/odte/tests/Odte_tests.py index 9945016..01ca3bc 100644 --- a/odte/tests/Odte_tests.py +++ b/odte/tests/Odte_tests.py @@ -192,33 +192,25 @@ class Odte_test(unittest.TestCase): tclf = Odte( base_estimator=Stree(), random_state=self._random_state, - n_estimators=3, + n_estimators=5, n_jobs=1, ) - X, y = load_dataset(self._random_state, n_features=16, n_samples=500) - tclf.fit(X, y) - self.assertAlmostEqual(6.333333333333333, tclf.depth_) - self.assertAlmostEqual(10.0, tclf.leaves_) - self.assertAlmostEqual(19.0, tclf.nodes_) - nodes, leaves = tclf.nodes_leaves() - self.assertAlmostEqual(10.0, leaves) - self.assertAlmostEqual(19, nodes) - - def test_nodes_leaves_depth_parallel(self): - tclf = Odte( + tclf_p = Odte( base_estimator=Stree(), random_state=self._random_state, - n_estimators=3, + n_estimators=5, n_jobs=-1, ) X, y = load_dataset(self._random_state, n_features=16, n_samples=500) tclf.fit(X, y) - self.assertAlmostEqual(6.333333333333333, tclf.depth_) - self.assertAlmostEqual(10.0, tclf.leaves_) - self.assertAlmostEqual(19.0, tclf.nodes_) - nodes, leaves = tclf.nodes_leaves() - self.assertAlmostEqual(10.0, leaves) - self.assertAlmostEqual(19, nodes) + tclf_p.fit(X, y) + for clf in [tclf, tclf_p]: + self.assertAlmostEqual(5.8, clf.depth_) + self.assertAlmostEqual(9.4, clf.leaves_) + self.assertAlmostEqual(17.8, clf.nodes_) + nodes, leaves = clf.nodes_leaves() + self.assertAlmostEqual(9.4, leaves) + self.assertAlmostEqual(17.8, nodes) def test_nodes_leaves_SVC(self): tclf = Odte(