mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-16 07:56:06 +00:00
Add version of the model method
This commit is contained in:
@@ -17,6 +17,7 @@ from sklearn.utils.validation import (
|
||||
_check_sample_weight,
|
||||
)
|
||||
from .Splitter import Splitter, Snode, Siterator
|
||||
from ._version import __version__
|
||||
|
||||
|
||||
class Stree(BaseEstimator, ClassifierMixin):
|
||||
@@ -169,6 +170,11 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
self.normalize = normalize
|
||||
self.multiclass_strategy = multiclass_strategy
|
||||
|
||||
@staticmethod
|
||||
def version() -> str:
|
||||
"""Return the version of the package."""
|
||||
return __version__
|
||||
|
||||
def _more_tags(self) -> dict:
|
||||
"""Required by sklearn to supply features of the classifier
|
||||
make mandatory the labels array
|
||||
|
@@ -1,6 +1,5 @@
|
||||
from .Strees import Stree, Siterator
|
||||
|
||||
__version__ = "1.2.3"
|
||||
from ._version import __version__
|
||||
|
||||
__author__ = "Ricardo Montañana Gómez"
|
||||
__copyright__ = "Copyright 2020-2021, Ricardo Montañana Gómez"
|
||||
|
1
stree/_version.py
Normal file
1
stree/_version.py
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = "1.2.3"
|
@@ -10,6 +10,7 @@ from sklearn.svm import LinearSVC
|
||||
from stree import Stree
|
||||
from stree.Splitter import Snode
|
||||
from .utils import load_dataset
|
||||
from .._version import __version__
|
||||
|
||||
|
||||
class Stree_test(unittest.TestCase):
|
||||
@@ -661,3 +662,7 @@ class Stree_test(unittest.TestCase):
|
||||
clf = Stree(multiclass_strategy="ovo", split_criteria="max_samples")
|
||||
with self.assertRaises(ValueError):
|
||||
clf.fit(X, y)
|
||||
|
||||
def test_version(self):
|
||||
clf = Stree()
|
||||
self.assertEqual(__version__, clf.version())
|
||||
|
Reference in New Issue
Block a user