Add version of the model method

This commit is contained in:
2021-12-17 11:01:09 +01:00
parent 08222f109e
commit 00ed57c015
5 changed files with 16 additions and 3 deletions

View File

@@ -17,6 +17,7 @@ from sklearn.utils.validation import (
_check_sample_weight,
)
from .Splitter import Splitter, Snode, Siterator
from ._version import __version__
class Stree(BaseEstimator, ClassifierMixin):
@@ -169,6 +170,11 @@ class Stree(BaseEstimator, ClassifierMixin):
self.normalize = normalize
self.multiclass_strategy = multiclass_strategy
@staticmethod
def version() -> str:
"""Return the version of the package."""
return __version__
def _more_tags(self) -> dict:
"""Required by sklearn to supply features of the classifier
make mandatory the labels array

View File

@@ -1,6 +1,5 @@
from .Strees import Stree, Siterator
__version__ = "1.2.3"
from ._version import __version__
__author__ = "Ricardo Montañana Gómez"
__copyright__ = "Copyright 2020-2021, Ricardo Montañana Gómez"

1
stree/_version.py Normal file
View File

@@ -0,0 +1 @@
__version__ = "1.2.3"

View File

@@ -10,6 +10,7 @@ from sklearn.svm import LinearSVC
from stree import Stree
from stree.Splitter import Snode
from .utils import load_dataset
from .._version import __version__
class Stree_test(unittest.TestCase):
@@ -661,3 +662,7 @@ class Stree_test(unittest.TestCase):
clf = Stree(multiclass_strategy="ovo", split_criteria="max_samples")
with self.assertRaises(ValueError):
clf.fit(X, y)
def test_version(self):
clf = Stree()
self.assertEqual(__version__, clf.version())