Compare commits

...

7 Commits

Author SHA1 Message Date
941c2ff5e0 Update gh action version 2024-08-14 10:15:26 +02:00
2ebf48145d Update python version requirements 2024-08-14 10:03:57 +02:00
7fbfd3622e Update python versions in gh actions 2024-08-14 09:58:36 +02:00
bc839a80d6 Remove black from lint in github actions 2024-08-14 09:52:05 +02:00
ba15ea2cc0 Remove unneeded file 2024-08-14 09:42:59 +02:00
85b56785c8 Change project builder to hatch
Update actions in Makefile
2024-08-14 09:41:45 +02:00
b627bb7531 Add pyproject.toml install information
Add __call__ method to support sklearn ensembles requirements for base estimators
Update tests
2024-08-13 13:28:32 +02:00
11 changed files with 118 additions and 102 deletions

View File

@@ -13,12 +13,12 @@ jobs:
strategy:
matrix:
os: [macos-latest, ubuntu-latest, windows-latest]
python: [3.8, "3.10"]
python: [3.11, 3.12]
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python }}
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
- name: Install dependencies
@@ -28,14 +28,14 @@ jobs:
pip install -q --upgrade codecov coverage black flake8 codacy-coverage
- name: Lint
run: |
black --check --diff stree
# black --check --diff stree
flake8 --count stree
- name: Tests
run: |
coverage run -m unittest -v stree.tests
coverage xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.xml

1
MANIFEST.in Normal file
View File

@@ -0,0 +1 @@
include README.md LICENSE

View File

@@ -1,46 +1,36 @@
SHELL := /bin/bash
.DEFAULT_GOAL := help
.PHONY: coverage deps help lint push test doc build
.PHONY: audit coverage help lint test doc doc-clean build
coverage: ## Run tests with coverage
coverage erase
coverage run -m unittest -v stree.tests
coverage report -m
@coverage erase
@coverage run -m unittest -v stree.tests
@coverage report -m
deps: ## Install dependencies
pip install -r requirements.txt
devdeps: ## Install development dependencies
pip install black pip-audit flake8 mypy coverage
lint: ## Lint and static-check
black stree
flake8 stree
mypy stree
push: ## Push code with tags
git push && git push --tags
lint: ## Lint source files
@black stree
@flake8 stree
test: ## Run tests
python -m unittest -v stree.tests
@python -m unittest -v stree.tests
doc: ## Update documentation
make -C docs --makefile=Makefile html
@make -C docs --makefile=Makefile html
build: ## Build package
rm -fr dist/*
rm -fr build/*
python setup.py sdist bdist_wheel
@rm -fr dist/*
@rm -fr build/*
@hatch build
doc-clean: ## Update documentation
make -C docs --makefile=Makefile clean
doc-clean: ## Clean documentation folders
@make -C docs --makefile=Makefile clean
audit: ## Audit pip
pip-audit
@pip-audit
help: ## Show help message
help: ## Show this help message
@IFS=$$'\n' ; \
help_lines=(`fgrep -h "##" $(MAKEFILE_LIST) | fgrep -v fgrep | sed -e 's/\\$$//' | sed -e 's/##/:/'`); \
help_lines=(`grep -Fh "##" $(MAKEFILE_LIST) | grep -Fv fgrep | sed -e 's/\\$$//' | sed -e 's/##/:/'`); \
printf "%s\n\n" "Usage: make [task]"; \
printf "%-20s %s\n" "task" "help" ; \
printf "%-20s %s\n" "------" "----" ; \

View File

@@ -1,5 +1,65 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "STree"
dependencies = ["scikit-learn>1.0", "mufs"]
license = { file = "LICENSE" }
description = "Oblique decision tree with svm nodes."
readme = "README.md"
authors = [
{ name = "Ricardo Montañana", email = "ricardo.montanana@alu.uclm.es" },
]
dynamic = ['version']
requires-python = ">=3.11"
keywords = [
"scikit-learn",
"oblique-classifier",
"oblique-decision-tree",
"decision-tree",
"svm",
"svc",
]
classifiers = [
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Science/Research",
"Intended Audience :: Developers",
"Topic :: Software Development",
"Topic :: Scientific/Engineering",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
[project.optional-dependencies]
dev = ["black", "flake8", "coverage", "hatch", "pip-audit"]
doc = ["sphinx", "myst-parser", "sphinx_rtd_theme", "sphinx-autodoc-typehints"]
[project.urls]
Code = "https://github.com/Doctorado-ML/STree"
Documentation = "https://stree.readthedocs.io/en/latest/index.html"
[tool.hatch.version]
path = "stree/_version.py"
[tool.hatch.build.targets.sdist]
include = ["/stree"]
[tool.coverage.run]
branch = true
source = ["stree"]
command_line = "-m unittest discover -s stree.tests"
[tool.coverage.report]
show_missing = true
fail_under = 100
[tool.black]
line-length = 79
target-version = ["py311"]
include = '\.pyi?$'
exclude = '''
/(
@@ -13,4 +73,4 @@ exclude = '''
| build
| dist
)/
'''
'''

View File

@@ -1 +0,0 @@
python-3.8

View File

@@ -1,56 +0,0 @@
import setuptools
import os
def readme():
with open("README.md") as f:
return f.read()
def get_data(field, file_name="__init__.py"):
item = ""
with open(os.path.join("stree", file_name)) as f:
for line in f.readlines():
if line.startswith(f"__{field}__"):
delim = '"' if '"' in line else "'"
item = line.split(delim)[1]
break
else:
raise RuntimeError(f"Unable to find {field} string.")
return item
def get_requirements():
with open("requirements.txt") as f:
return f.read().splitlines()
setuptools.setup(
name="STree",
version=get_data("version", "_version.py"),
license=get_data("license"),
description="Oblique decision tree with svm nodes",
long_description=readme(),
long_description_content_type="text/markdown",
packages=setuptools.find_packages(),
url="https://github.com/Doctorado-ML/STree#stree",
project_urls={
"Code": "https://github.com/Doctorado-ML/STree",
"Documentation": "https://stree.readthedocs.io/en/latest/index.html",
},
author=get_data("author"),
author_email=get_data("author_email"),
keywords="scikit-learn oblique-classifier oblique-decision-tree decision-\
tree svm svc",
classifiers=[
"Development Status :: 5 - Production/Stable",
"License :: OSI Approved :: " + get_data("license"),
"Programming Language :: Python :: 3.8",
"Natural Language :: English",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Intended Audience :: Science/Research",
],
install_requires=get_requirements(),
test_suite="stree.tests",
zip_safe=False,
)

View File

@@ -414,7 +414,8 @@ class Splitter:
)
return tuple(
sorted(
range(len(feature_list)), key=lambda sub: feature_list[sub]
range(len(feature_list)),
key=lambda sub: feature_list[sub],
)[-max_features:]
)
@@ -529,7 +530,10 @@ class Splitter:
return entropy
def information_gain(
self, labels: np.array, labels_up: np.array, labels_dn: np.array
self,
labels: np.array,
labels_up: np.array,
labels_dn: np.array,
) -> float:
"""Compute information gain of a split candidate

View File

@@ -174,6 +174,11 @@ class Stree(BaseEstimator, ClassifierMixin):
"""Return the version of the package."""
return __version__
def __call__(self) -> str:
"""Only added to comply with scikit-learn base estimator for ensembles
"""
return self.version()
def _more_tags(self) -> dict:
"""Required by sklearn to supply features of the classifier
make mandatory the labels array
@@ -184,7 +189,10 @@ class Stree(BaseEstimator, ClassifierMixin):
return {"requires_y": True}
def fit(
self, X: np.ndarray, y: np.ndarray, sample_weight: np.array = None
self,
X: np.ndarray,
y: np.ndarray,
sample_weight: np.array = None,
) -> "Stree":
"""Build the tree based on the dataset of samples and its labels
@@ -339,7 +347,11 @@ class Stree(BaseEstimator, ClassifierMixin):
)
node.set_down(
self._train(
X_D, y_d, sw_d, depth + 1, title + f" - Down({depth+1})"
X_D,
y_d,
sw_d,
depth + 1,
title + f" - Down({depth+1})",
)
)
return node

View File

@@ -1,8 +1,9 @@
from .Strees import Stree, Siterator
from ._version import __version__
__author__ = "Ricardo Montañana Gómez"
__copyright__ = "Copyright 2020-2021, Ricardo Montañana Gómez"
__license__ = "MIT License"
__author_email__ = "ricardo.montanana@alu.uclm.es"
__all__ = ["Stree", "Siterator"]
__all__ = ["__version__", "Stree", "Siterator"]

View File

@@ -1 +1 @@
__version__ = "1.3.2"
__version__ = "1.4.0"

View File

@@ -289,12 +289,12 @@ class Stree_test(unittest.TestCase):
"impurity sigmoid": 0.824,
},
"Iris": {
"max_samples liblinear": 0.9550561797752809,
"max_samples liblinear": 0.9887640449438202,
"max_samples linear": 1.0,
"max_samples rbf": 0.6685393258426966,
"max_samples poly": 0.6853932584269663,
"max_samples sigmoid": 0.6404494382022472,
"impurity liblinear": 0.9550561797752809,
"impurity liblinear": 0.9887640449438202,
"impurity linear": 1.0,
"impurity rbf": 0.6685393258426966,
"impurity poly": 0.6853932584269663,
@@ -440,10 +440,10 @@ class Stree_test(unittest.TestCase):
clf.fit(X, y)
score = clf.score(X, y)
# Check accuracy of the whole model
self.assertAlmostEquals(0.98, score, 5)
self.assertAlmostEqual(0.98, score, 5)
svm = LinearSVC(random_state=0)
svm.fit(X, y)
self.assertAlmostEquals(0.9666666666666667, svm.score(X, y), 5)
self.assertAlmostEqual(0.9666666666666667, svm.score(X, y), 5)
data = svm.decision_function(X)
expected = [
0.4444444444444444,
@@ -455,7 +455,7 @@ class Stree_test(unittest.TestCase):
ty[data > 0] = 1
ty = ty.astype(int)
for i in range(3):
self.assertAlmostEquals(
self.assertAlmostEqual(
expected[i],
clf.splitter_._gini(ty[:, i]),
)
@@ -593,7 +593,7 @@ class Stree_test(unittest.TestCase):
)
self.assertEqual(0.9526666666666667, clf2.fit(X, y).score(X, y))
X, y = load_wine(return_X_y=True)
self.assertEqual(0.9831460674157303, clf.fit(X, y).score(X, y))
self.assertEqual(0.9887640449438202, clf.fit(X, y).score(X, y))
self.assertEqual(1.0, clf2.fit(X, y).score(X, y))
def test_zero_all_sample_weights(self):
@@ -725,6 +725,11 @@ class Stree_test(unittest.TestCase):
clf = Stree()
self.assertEqual(__version__, clf.version())
def test_call(self) -> None:
"""Check call method."""
clf = Stree()
self.assertEqual(__version__, clf())
def test_graph(self):
"""Check graphviz representation of the tree."""
X, y = load_wine(return_X_y=True)