mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-15 23:46:02 +00:00
#6 - Update tests and codecov conf
This commit is contained in:
@@ -3,9 +3,6 @@ overage:
|
||||
project:
|
||||
default:
|
||||
target: 90%
|
||||
patch:
|
||||
default:
|
||||
target: 90%
|
||||
comment:
|
||||
layout: "reach, diff, flags, files"
|
||||
behavior: default
|
||||
|
@@ -19,7 +19,6 @@ from sklearn.utils.validation import (
|
||||
check_is_fitted,
|
||||
_check_sample_weight,
|
||||
)
|
||||
from sklearn.utils.sparsefuncs import count_nonzero
|
||||
from sklearn.metrics._classification import _weighted_sum, _check_targets
|
||||
|
||||
|
||||
@@ -422,11 +421,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
# Compute accuracy for each possible representation
|
||||
y_type, y_true, y_pred = _check_targets(y, y_pred)
|
||||
check_consistent_length(y_true, y_pred, sample_weight)
|
||||
if y_type.startswith("multilabel"):
|
||||
differing_labels = count_nonzero(y_true - y_pred, axis=1)
|
||||
score = differing_labels == 0
|
||||
else:
|
||||
score = y_true == y_pred
|
||||
score = y_true == y_pred
|
||||
return _weighted_sum(score, sample_weight, normalize=True)
|
||||
|
||||
def __iter__(self) -> Siterator:
|
||||
|
@@ -68,6 +68,11 @@ class Stree_grapher_test(unittest.TestCase):
|
||||
self.assertEqual(accuracy_score, accuracy_computed)
|
||||
self.assertGreater(accuracy_score, 0.86)
|
||||
|
||||
def test_score_4dims(self):
|
||||
X, y = get_dataset(self._random_state, n_features=4)
|
||||
accuracy_score = self._clf.score(X, y)
|
||||
self.assertEqual(accuracy_score, 0.95)
|
||||
|
||||
def test_save_all(self):
|
||||
folder_name = os.path.join(os.sep, "tmp", "stree")
|
||||
if os.path.isdir(folder_name):
|
||||
@@ -171,11 +176,13 @@ class Snode_graph_test(unittest.TestCase):
|
||||
|
||||
def test_plot_hyperplane_with_distribution(self):
|
||||
plt.close()
|
||||
# select a pure node
|
||||
node = self._clf._tree_gr.get_down().get_up().get_up()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
matplotlib.use("Agg")
|
||||
num_figures_before = plt.gcf().number
|
||||
self._clf._tree_gr.plot_hyperplane(plot_distribution=True)
|
||||
node.plot_hyperplane(plot_distribution=True)
|
||||
num_figures_after = plt.gcf().number
|
||||
self.assertEqual(1, num_figures_after - num_figures_before)
|
||||
|
||||
@@ -209,3 +216,11 @@ class Snode_graph_test(unittest.TestCase):
|
||||
self.assertEqual(x, xx)
|
||||
self.assertEqual(y, yy)
|
||||
self.assertEqual(z, zz)
|
||||
|
||||
def test_cmap_change(self):
|
||||
node = Snode_graph(Snode(None, None, None, "test"))
|
||||
self.assertEqual("jet", node._get_cmap())
|
||||
# make node pure
|
||||
node._belief = 1.0
|
||||
node._class = 1
|
||||
self.assertEqual("jet_r", node._get_cmap())
|
||||
|
Reference in New Issue
Block a user