Add pyproject.toml install information

Add __call__ method to support sklearn ensembles requirements for base estimators
Update tests
This commit is contained in:
2024-08-13 13:28:32 +02:00
parent 5f8ca8f3bb
commit b627bb7531
7 changed files with 83 additions and 65 deletions

View File

@@ -174,6 +174,10 @@ class Stree(BaseEstimator, ClassifierMixin):
"""Return the version of the package."""
return __version__
def __call__(self) -> str:
"""Only added to comply with scikit-learn base estimator for ensemble"""
return self.version()
def _more_tags(self) -> dict:
"""Required by sklearn to supply features of the classifier
make mandatory the labels array

View File

@@ -1,8 +1,9 @@
from .Strees import Stree, Siterator
from ._version import __version__
__author__ = "Ricardo Montañana Gómez"
__copyright__ = "Copyright 2020-2021, Ricardo Montañana Gómez"
__license__ = "MIT License"
__author_email__ = "ricardo.montanana@alu.uclm.es"
__all__ = ["Stree", "Siterator"]
__all__ = ["__version__", "Stree", "Siterator"]

View File

@@ -1 +1 @@
__version__ = "1.3.2"
__version__ = "1.4.0"

View File

@@ -289,12 +289,12 @@ class Stree_test(unittest.TestCase):
"impurity sigmoid": 0.824,
},
"Iris": {
"max_samples liblinear": 0.9550561797752809,
"max_samples liblinear": 0.9887640449438202,
"max_samples linear": 1.0,
"max_samples rbf": 0.6685393258426966,
"max_samples poly": 0.6853932584269663,
"max_samples sigmoid": 0.6404494382022472,
"impurity liblinear": 0.9550561797752809,
"impurity liblinear": 0.9887640449438202,
"impurity linear": 1.0,
"impurity rbf": 0.6685393258426966,
"impurity poly": 0.6853932584269663,
@@ -440,10 +440,10 @@ class Stree_test(unittest.TestCase):
clf.fit(X, y)
score = clf.score(X, y)
# Check accuracy of the whole model
self.assertAlmostEquals(0.98, score, 5)
self.assertAlmostEqual(0.98, score, 5)
svm = LinearSVC(random_state=0)
svm.fit(X, y)
self.assertAlmostEquals(0.9666666666666667, svm.score(X, y), 5)
self.assertAlmostEqual(0.9666666666666667, svm.score(X, y), 5)
data = svm.decision_function(X)
expected = [
0.4444444444444444,
@@ -455,7 +455,7 @@ class Stree_test(unittest.TestCase):
ty[data > 0] = 1
ty = ty.astype(int)
for i in range(3):
self.assertAlmostEquals(
self.assertAlmostEqual(
expected[i],
clf.splitter_._gini(ty[:, i]),
)
@@ -593,7 +593,7 @@ class Stree_test(unittest.TestCase):
)
self.assertEqual(0.9526666666666667, clf2.fit(X, y).score(X, y))
X, y = load_wine(return_X_y=True)
self.assertEqual(0.9831460674157303, clf.fit(X, y).score(X, y))
self.assertEqual(0.9887640449438202, clf.fit(X, y).score(X, y))
self.assertEqual(1.0, clf2.fit(X, y).score(X, y))
def test_zero_all_sample_weights(self):
@@ -725,6 +725,11 @@ class Stree_test(unittest.TestCase):
clf = Stree()
self.assertEqual(__version__, clf.version())
def test_call(self) -> None:
"""Check call method."""
clf = Stree()
self.assertEqual(__version__, clf())
def test_graph(self):
"""Check graphviz representation of the tree."""
X, y = load_wine(return_X_y=True)