From 00ed57c015558748cf7663f1b71fd42e9e3716e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Fri, 17 Dec 2021 11:01:09 +0100 Subject: [PATCH] Add version of the model method --- setup.py | 4 +++- stree/Strees.py | 6 ++++++ stree/__init__.py | 3 +-- stree/_version.py | 1 + stree/tests/Stree_test.py | 5 +++++ 5 files changed, 16 insertions(+), 3 deletions(-) create mode 100644 stree/_version.py diff --git a/setup.py b/setup.py index ee60a5c..b58d071 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,5 @@ import setuptools +import os def readme(): @@ -8,7 +9,8 @@ def readme(): def get_data(field): item = "" - with open("stree/__init__.py") as f: + file_name = "_version.py" if field == "version" else "__init__.py" + with open(os.path.join("stree", file_name)) as f: for line in f.readlines(): if line.startswith(f"__{field}__"): delim = '"' if '"' in line else "'" diff --git a/stree/Strees.py b/stree/Strees.py index 1857bca..ccb1044 100644 --- a/stree/Strees.py +++ b/stree/Strees.py @@ -17,6 +17,7 @@ from sklearn.utils.validation import ( _check_sample_weight, ) from .Splitter import Splitter, Snode, Siterator +from ._version import __version__ class Stree(BaseEstimator, ClassifierMixin): @@ -169,6 +170,11 @@ class Stree(BaseEstimator, ClassifierMixin): self.normalize = normalize self.multiclass_strategy = multiclass_strategy + @staticmethod + def version() -> str: + """Return the version of the package.""" + return __version__ + def _more_tags(self) -> dict: """Required by sklearn to supply features of the classifier make mandatory the labels array diff --git a/stree/__init__.py b/stree/__init__.py index 546ee03..e414209 100644 --- a/stree/__init__.py +++ b/stree/__init__.py @@ -1,6 +1,5 @@ from .Strees import Stree, Siterator - -__version__ = "1.2.3" +from ._version import __version__ __author__ = "Ricardo Montañana Gómez" __copyright__ = "Copyright 2020-2021, Ricardo Montañana Gómez" diff --git a/stree/_version.py b/stree/_version.py new file mode 100644 index 0000000..10aa336 --- /dev/null +++ b/stree/_version.py @@ -0,0 +1 @@ +__version__ = "1.2.3" diff --git a/stree/tests/Stree_test.py b/stree/tests/Stree_test.py index af43cdd..3813de4 100644 --- a/stree/tests/Stree_test.py +++ b/stree/tests/Stree_test.py @@ -10,6 +10,7 @@ from sklearn.svm import LinearSVC from stree import Stree from stree.Splitter import Snode from .utils import load_dataset +from .._version import __version__ class Stree_test(unittest.TestCase): @@ -661,3 +662,7 @@ class Stree_test(unittest.TestCase): clf = Stree(multiclass_strategy="ovo", split_criteria="max_samples") with self.assertRaises(ValueError): clf.fit(X, y) + + def test_version(self): + clf = Stree() + self.assertEqual(__version__, clf.version())