mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-17 00:16:07 +00:00
#6 - Update tests and codecov conf
This commit is contained in:
@@ -3,9 +3,6 @@ overage:
|
|||||||
project:
|
project:
|
||||||
default:
|
default:
|
||||||
target: 90%
|
target: 90%
|
||||||
patch:
|
|
||||||
default:
|
|
||||||
target: 90%
|
|
||||||
comment:
|
comment:
|
||||||
layout: "reach, diff, flags, files"
|
layout: "reach, diff, flags, files"
|
||||||
behavior: default
|
behavior: default
|
||||||
|
@@ -19,7 +19,6 @@ from sklearn.utils.validation import (
|
|||||||
check_is_fitted,
|
check_is_fitted,
|
||||||
_check_sample_weight,
|
_check_sample_weight,
|
||||||
)
|
)
|
||||||
from sklearn.utils.sparsefuncs import count_nonzero
|
|
||||||
from sklearn.metrics._classification import _weighted_sum, _check_targets
|
from sklearn.metrics._classification import _weighted_sum, _check_targets
|
||||||
|
|
||||||
|
|
||||||
@@ -422,11 +421,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
# Compute accuracy for each possible representation
|
# Compute accuracy for each possible representation
|
||||||
y_type, y_true, y_pred = _check_targets(y, y_pred)
|
y_type, y_true, y_pred = _check_targets(y, y_pred)
|
||||||
check_consistent_length(y_true, y_pred, sample_weight)
|
check_consistent_length(y_true, y_pred, sample_weight)
|
||||||
if y_type.startswith("multilabel"):
|
score = y_true == y_pred
|
||||||
differing_labels = count_nonzero(y_true - y_pred, axis=1)
|
|
||||||
score = differing_labels == 0
|
|
||||||
else:
|
|
||||||
score = y_true == y_pred
|
|
||||||
return _weighted_sum(score, sample_weight, normalize=True)
|
return _weighted_sum(score, sample_weight, normalize=True)
|
||||||
|
|
||||||
def __iter__(self) -> Siterator:
|
def __iter__(self) -> Siterator:
|
||||||
|
@@ -68,6 +68,11 @@ class Stree_grapher_test(unittest.TestCase):
|
|||||||
self.assertEqual(accuracy_score, accuracy_computed)
|
self.assertEqual(accuracy_score, accuracy_computed)
|
||||||
self.assertGreater(accuracy_score, 0.86)
|
self.assertGreater(accuracy_score, 0.86)
|
||||||
|
|
||||||
|
def test_score_4dims(self):
|
||||||
|
X, y = get_dataset(self._random_state, n_features=4)
|
||||||
|
accuracy_score = self._clf.score(X, y)
|
||||||
|
self.assertEqual(accuracy_score, 0.95)
|
||||||
|
|
||||||
def test_save_all(self):
|
def test_save_all(self):
|
||||||
folder_name = os.path.join(os.sep, "tmp", "stree")
|
folder_name = os.path.join(os.sep, "tmp", "stree")
|
||||||
if os.path.isdir(folder_name):
|
if os.path.isdir(folder_name):
|
||||||
@@ -171,11 +176,13 @@ class Snode_graph_test(unittest.TestCase):
|
|||||||
|
|
||||||
def test_plot_hyperplane_with_distribution(self):
|
def test_plot_hyperplane_with_distribution(self):
|
||||||
plt.close()
|
plt.close()
|
||||||
|
# select a pure node
|
||||||
|
node = self._clf._tree_gr.get_down().get_up().get_up()
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore")
|
warnings.simplefilter("ignore")
|
||||||
matplotlib.use("Agg")
|
matplotlib.use("Agg")
|
||||||
num_figures_before = plt.gcf().number
|
num_figures_before = plt.gcf().number
|
||||||
self._clf._tree_gr.plot_hyperplane(plot_distribution=True)
|
node.plot_hyperplane(plot_distribution=True)
|
||||||
num_figures_after = plt.gcf().number
|
num_figures_after = plt.gcf().number
|
||||||
self.assertEqual(1, num_figures_after - num_figures_before)
|
self.assertEqual(1, num_figures_after - num_figures_before)
|
||||||
|
|
||||||
@@ -209,3 +216,11 @@ class Snode_graph_test(unittest.TestCase):
|
|||||||
self.assertEqual(x, xx)
|
self.assertEqual(x, xx)
|
||||||
self.assertEqual(y, yy)
|
self.assertEqual(y, yy)
|
||||||
self.assertEqual(z, zz)
|
self.assertEqual(z, zz)
|
||||||
|
|
||||||
|
def test_cmap_change(self):
|
||||||
|
node = Snode_graph(Snode(None, None, None, "test"))
|
||||||
|
self.assertEqual("jet", node._get_cmap())
|
||||||
|
# make node pure
|
||||||
|
node._belief = 1.0
|
||||||
|
node._class = 1
|
||||||
|
self.assertEqual("jet_r", node._get_cmap())
|
||||||
|
Reference in New Issue
Block a user