mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-16 16:06:01 +00:00
Add version of the model method
This commit is contained in:
4
setup.py
4
setup.py
@@ -1,4 +1,5 @@
|
|||||||
import setuptools
|
import setuptools
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
def readme():
|
def readme():
|
||||||
@@ -8,7 +9,8 @@ def readme():
|
|||||||
|
|
||||||
def get_data(field):
|
def get_data(field):
|
||||||
item = ""
|
item = ""
|
||||||
with open("stree/__init__.py") as f:
|
file_name = "_version.py" if field == "version" else "__init__.py"
|
||||||
|
with open(os.path.join("stree", file_name)) as f:
|
||||||
for line in f.readlines():
|
for line in f.readlines():
|
||||||
if line.startswith(f"__{field}__"):
|
if line.startswith(f"__{field}__"):
|
||||||
delim = '"' if '"' in line else "'"
|
delim = '"' if '"' in line else "'"
|
||||||
|
@@ -17,6 +17,7 @@ from sklearn.utils.validation import (
|
|||||||
_check_sample_weight,
|
_check_sample_weight,
|
||||||
)
|
)
|
||||||
from .Splitter import Splitter, Snode, Siterator
|
from .Splitter import Splitter, Snode, Siterator
|
||||||
|
from ._version import __version__
|
||||||
|
|
||||||
|
|
||||||
class Stree(BaseEstimator, ClassifierMixin):
|
class Stree(BaseEstimator, ClassifierMixin):
|
||||||
@@ -169,6 +170,11 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
self.normalize = normalize
|
self.normalize = normalize
|
||||||
self.multiclass_strategy = multiclass_strategy
|
self.multiclass_strategy = multiclass_strategy
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def version() -> str:
|
||||||
|
"""Return the version of the package."""
|
||||||
|
return __version__
|
||||||
|
|
||||||
def _more_tags(self) -> dict:
|
def _more_tags(self) -> dict:
|
||||||
"""Required by sklearn to supply features of the classifier
|
"""Required by sklearn to supply features of the classifier
|
||||||
make mandatory the labels array
|
make mandatory the labels array
|
||||||
|
@@ -1,6 +1,5 @@
|
|||||||
from .Strees import Stree, Siterator
|
from .Strees import Stree, Siterator
|
||||||
|
from ._version import __version__
|
||||||
__version__ = "1.2.3"
|
|
||||||
|
|
||||||
__author__ = "Ricardo Montañana Gómez"
|
__author__ = "Ricardo Montañana Gómez"
|
||||||
__copyright__ = "Copyright 2020-2021, Ricardo Montañana Gómez"
|
__copyright__ = "Copyright 2020-2021, Ricardo Montañana Gómez"
|
||||||
|
1
stree/_version.py
Normal file
1
stree/_version.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
__version__ = "1.2.3"
|
@@ -10,6 +10,7 @@ from sklearn.svm import LinearSVC
|
|||||||
from stree import Stree
|
from stree import Stree
|
||||||
from stree.Splitter import Snode
|
from stree.Splitter import Snode
|
||||||
from .utils import load_dataset
|
from .utils import load_dataset
|
||||||
|
from .._version import __version__
|
||||||
|
|
||||||
|
|
||||||
class Stree_test(unittest.TestCase):
|
class Stree_test(unittest.TestCase):
|
||||||
@@ -661,3 +662,7 @@ class Stree_test(unittest.TestCase):
|
|||||||
clf = Stree(multiclass_strategy="ovo", split_criteria="max_samples")
|
clf = Stree(multiclass_strategy="ovo", split_criteria="max_samples")
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
clf.fit(X, y)
|
clf.fit(X, y)
|
||||||
|
|
||||||
|
def test_version(self):
|
||||||
|
clf = Stree()
|
||||||
|
self.assertEqual(__version__, clf.version())
|
||||||
|
Reference in New Issue
Block a user