mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-15 23:46:02 +00:00
Add version of the model method
This commit is contained in:
4
setup.py
4
setup.py
@@ -1,4 +1,5 @@
|
||||
import setuptools
|
||||
import os
|
||||
|
||||
|
||||
def readme():
|
||||
@@ -8,7 +9,8 @@ def readme():
|
||||
|
||||
def get_data(field):
|
||||
item = ""
|
||||
with open("stree/__init__.py") as f:
|
||||
file_name = "_version.py" if field == "version" else "__init__.py"
|
||||
with open(os.path.join("stree", file_name)) as f:
|
||||
for line in f.readlines():
|
||||
if line.startswith(f"__{field}__"):
|
||||
delim = '"' if '"' in line else "'"
|
||||
|
@@ -17,6 +17,7 @@ from sklearn.utils.validation import (
|
||||
_check_sample_weight,
|
||||
)
|
||||
from .Splitter import Splitter, Snode, Siterator
|
||||
from ._version import __version__
|
||||
|
||||
|
||||
class Stree(BaseEstimator, ClassifierMixin):
|
||||
@@ -169,6 +170,11 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
self.normalize = normalize
|
||||
self.multiclass_strategy = multiclass_strategy
|
||||
|
||||
@staticmethod
|
||||
def version() -> str:
|
||||
"""Return the version of the package."""
|
||||
return __version__
|
||||
|
||||
def _more_tags(self) -> dict:
|
||||
"""Required by sklearn to supply features of the classifier
|
||||
make mandatory the labels array
|
||||
|
@@ -1,6 +1,5 @@
|
||||
from .Strees import Stree, Siterator
|
||||
|
||||
__version__ = "1.2.3"
|
||||
from ._version import __version__
|
||||
|
||||
__author__ = "Ricardo Montañana Gómez"
|
||||
__copyright__ = "Copyright 2020-2021, Ricardo Montañana Gómez"
|
||||
|
1
stree/_version.py
Normal file
1
stree/_version.py
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = "1.2.3"
|
@@ -10,6 +10,7 @@ from sklearn.svm import LinearSVC
|
||||
from stree import Stree
|
||||
from stree.Splitter import Snode
|
||||
from .utils import load_dataset
|
||||
from .._version import __version__
|
||||
|
||||
|
||||
class Stree_test(unittest.TestCase):
|
||||
@@ -661,3 +662,7 @@ class Stree_test(unittest.TestCase):
|
||||
clf = Stree(multiclass_strategy="ovo", split_criteria="max_samples")
|
||||
with self.assertRaises(ValueError):
|
||||
clf.fit(X, y)
|
||||
|
||||
def test_version(self):
|
||||
clf = Stree()
|
||||
self.assertEqual(__version__, clf.version())
|
||||
|
Reference in New Issue
Block a user