mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-18 00:46:02 +00:00
Add pyproject.toml install information
Add __call__ method to support sklearn ensembles requirements for base estimators Update tests
This commit is contained in:
@@ -174,6 +174,10 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
"""Return the version of the package."""
|
||||
return __version__
|
||||
|
||||
def __call__(self) -> str:
|
||||
"""Only added to comply with scikit-learn base estimator for ensemble"""
|
||||
return self.version()
|
||||
|
||||
def _more_tags(self) -> dict:
|
||||
"""Required by sklearn to supply features of the classifier
|
||||
make mandatory the labels array
|
||||
|
@@ -1,8 +1,9 @@
|
||||
from .Strees import Stree, Siterator
|
||||
from ._version import __version__
|
||||
|
||||
__author__ = "Ricardo Montañana Gómez"
|
||||
__copyright__ = "Copyright 2020-2021, Ricardo Montañana Gómez"
|
||||
__license__ = "MIT License"
|
||||
__author_email__ = "ricardo.montanana@alu.uclm.es"
|
||||
|
||||
__all__ = ["Stree", "Siterator"]
|
||||
__all__ = ["__version__", "Stree", "Siterator"]
|
||||
|
@@ -1 +1 @@
|
||||
__version__ = "1.3.2"
|
||||
__version__ = "1.4.0"
|
||||
|
@@ -289,12 +289,12 @@ class Stree_test(unittest.TestCase):
|
||||
"impurity sigmoid": 0.824,
|
||||
},
|
||||
"Iris": {
|
||||
"max_samples liblinear": 0.9550561797752809,
|
||||
"max_samples liblinear": 0.9887640449438202,
|
||||
"max_samples linear": 1.0,
|
||||
"max_samples rbf": 0.6685393258426966,
|
||||
"max_samples poly": 0.6853932584269663,
|
||||
"max_samples sigmoid": 0.6404494382022472,
|
||||
"impurity liblinear": 0.9550561797752809,
|
||||
"impurity liblinear": 0.9887640449438202,
|
||||
"impurity linear": 1.0,
|
||||
"impurity rbf": 0.6685393258426966,
|
||||
"impurity poly": 0.6853932584269663,
|
||||
@@ -440,10 +440,10 @@ class Stree_test(unittest.TestCase):
|
||||
clf.fit(X, y)
|
||||
score = clf.score(X, y)
|
||||
# Check accuracy of the whole model
|
||||
self.assertAlmostEquals(0.98, score, 5)
|
||||
self.assertAlmostEqual(0.98, score, 5)
|
||||
svm = LinearSVC(random_state=0)
|
||||
svm.fit(X, y)
|
||||
self.assertAlmostEquals(0.9666666666666667, svm.score(X, y), 5)
|
||||
self.assertAlmostEqual(0.9666666666666667, svm.score(X, y), 5)
|
||||
data = svm.decision_function(X)
|
||||
expected = [
|
||||
0.4444444444444444,
|
||||
@@ -455,7 +455,7 @@ class Stree_test(unittest.TestCase):
|
||||
ty[data > 0] = 1
|
||||
ty = ty.astype(int)
|
||||
for i in range(3):
|
||||
self.assertAlmostEquals(
|
||||
self.assertAlmostEqual(
|
||||
expected[i],
|
||||
clf.splitter_._gini(ty[:, i]),
|
||||
)
|
||||
@@ -593,7 +593,7 @@ class Stree_test(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(0.9526666666666667, clf2.fit(X, y).score(X, y))
|
||||
X, y = load_wine(return_X_y=True)
|
||||
self.assertEqual(0.9831460674157303, clf.fit(X, y).score(X, y))
|
||||
self.assertEqual(0.9887640449438202, clf.fit(X, y).score(X, y))
|
||||
self.assertEqual(1.0, clf2.fit(X, y).score(X, y))
|
||||
|
||||
def test_zero_all_sample_weights(self):
|
||||
@@ -725,6 +725,11 @@ class Stree_test(unittest.TestCase):
|
||||
clf = Stree()
|
||||
self.assertEqual(__version__, clf.version())
|
||||
|
||||
def test_call(self) -> None:
|
||||
"""Check call method."""
|
||||
clf = Stree()
|
||||
self.assertEqual(__version__, clf())
|
||||
|
||||
def test_graph(self):
|
||||
"""Check graphviz representation of the tree."""
|
||||
X, y = load_wine(return_X_y=True)
|
||||
|
Reference in New Issue
Block a user