mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-16 07:56:06 +00:00
Make project python package friendly
- Add setup.py - Move classes to module files - Move tests folder inside module folder
This commit is contained in:
@@ -10,4 +10,4 @@ notifications:
|
|||||||
on_success: never # default: change
|
on_success: never # default: change
|
||||||
on_failure: always # default: always
|
on_failure: always # default: always
|
||||||
# command to run tests
|
# command to run tests
|
||||||
script: python -m unittest tests.Stree_test tests.Snode_test
|
script: python -m unittest stree.tests
|
14
README.md
14
README.md
@@ -4,7 +4,13 @@
|
|||||||
|
|
||||||
Oblique Tree classifier based on SVM nodes
|
Oblique Tree classifier based on SVM nodes
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install git+https://github.com/doctorado-ml/stree
|
||||||
|
```
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
@@ -24,12 +30,12 @@ Oblique Tree classifier based on SVM nodes
|
|||||||
|
|
||||||
### Command line
|
### Command line
|
||||||
|
|
||||||
```python
|
```bash
|
||||||
python main.py
|
python main.py
|
||||||
```
|
```
|
||||||
|
|
||||||
## Tests
|
## Tests
|
||||||
|
|
||||||
```python
|
```bash
|
||||||
python -m unittest -v tests.Stree_test tests.Snode_test
|
python -m unittest -v stree.tests
|
||||||
```
|
```
|
||||||
|
File diff suppressed because one or more lines are too long
2
main.py
2
main.py
@@ -1,6 +1,6 @@
|
|||||||
import time
|
import time
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
from trees.Stree import Stree
|
from stree import Stree
|
||||||
|
|
||||||
random_state=1
|
random_state=1
|
||||||
|
|
||||||
|
40
setup.py
Normal file
40
setup.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
import os
import re

import setuptools

# Directory that holds this setup.py, so files are found regardless of the
# directory the build is launched from.
_HERE = os.path.dirname(os.path.abspath(__file__))


def _read(*parts):
    """Return the UTF-8 decoded content of a file located relative to setup.py

    :param parts: path components below the project root
    :type parts: str
    :return: content of the file
    :rtype: str
    """
    with open(os.path.join(_HERE, *parts), encoding='utf-8') as f:
        return f.read()


def _get_data(field):
    """Extract the value of a ``__field__ = "value"`` assignment from
    stree/__init__.py without importing the package.

    Importing stree here would also import its dependencies, which are not
    guaranteed to be installed when setup.py is executed.

    :param field: name of the dunder without underscores, i.e. 'version'
    :type field: str
    :raises RuntimeError: if the assignment is not found
    :return: the string assigned to the dunder
    :rtype: str
    """
    match = re.search(
        rf'^__{field}__\s*=\s*["\']([^"\']*)["\']',
        _read('stree', '__init__.py'),
        re.MULTILINE,
    )
    if match is None:
        raise RuntimeError(f"Unable to find __{field}__ in stree/__init__.py")
    return match.group(1)


def readme():
    """Return the content of README.md to be used as long description
    """
    return _read('README.md')


setuptools.setup(
    name='STree',
    version=_get_data('version'),
    license='MIT License',
    description='a python interface to oblique decision tree implementations',
    long_description=readme(),
    long_description_content_type='text/markdown',
    packages=['stree'],
    url='https://github.com/doctorado-ml/stree',
    author=_get_data('author'),
    author_email='ricardo.montanana@alu.uclm.es',
    keywords='scikit-learn oblique-classifier oblique-decision-tree decision-tree svm svc',
    classifiers=[
        'Development Status :: 4 - Beta',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.7',
        'Natural Language :: English',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Intended Audience :: Science/Research'
    ],
    install_requires=[
        'scikit-learn>=0.23.0',
        'numpy',
        'matplotlib',
        'ipympl'
    ],
    data_files=[('data', ['data/.gitignore'])],
    test_suite="stree.tests",
    zip_safe=False
)
|
@@ -8,20 +8,100 @@ Uses LinearSVC
|
|||||||
'''
|
'''
|
||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn.base import BaseEstimator, ClassifierMixin
|
from sklearn.base import BaseEstimator, ClassifierMixin
|
||||||
from sklearn.svm import LinearSVC
|
from sklearn.svm import LinearSVC
|
||||||
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
|
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
|
||||||
from trees.Snode import Snode
|
|
||||||
from trees.Siterator import Siterator
|
|
||||||
|
|
||||||
|
class Snode:
    """Node of the tree

    Keeps the classifier used to split the node, its coefficients, the labels
    of the samples that reached the node and the links to its two children.
    For a leaf it also keeps the predicted class and the belief in that
    prediction (see make_predictor).
    """

    def __init__(self, clf: 'LinearSVC', X: np.ndarray, y: np.ndarray,
                 title: str):
        """Build a node without children

        :param clf: classifier of the node, None if there is no classifier
        :type clf: LinearSVC
        :param X: samples of the node, only stored if the environment
            variable TESTING is set
        :type X: np.ndarray
        :param y: labels of the samples of the node
        :type y: np.ndarray
        :param title: text used to describe the node in __str__
        :type title: str
        """
        self._clf = clf
        self._vector = None if clf is None else clf.coef_
        self._interceptor = 0. if clf is None else clf.intercept_
        self._title = title
        # belief of the prediction in a leaf node based on samples
        self._belief = 0.
        # Only store dataset in Testing
        self._X = X if os.environ.get('TESTING', 'NS') != 'NS' else None
        self._y = y
        self._down = None
        self._up = None
        self._class = None

    @classmethod
    def copy(cls, node: 'Snode') -> 'Snode':
        """Build a new node with the classifier, dataset and title of node.
        Children, class and belief are not copied.
        """
        return cls(node._clf, node._X, node._y, node._title)

    def set_down(self, son):
        self._down = son

    def set_up(self, son):
        self._up = son

    def is_leaf(self) -> bool:
        """A node is a leaf if it has no children
        """
        return self._up is None and self._down is None

    def get_down(self) -> 'Snode':
        return self._down

    def get_up(self) -> 'Snode':
        return self._up

    def make_predictor(self):
        """Compute the class of the predictor and its belief based on the
        subdataset of the node only if it is a leaf

        The class is the most frequent label (the first one if there is a
        tie) and the belief is the proportion of samples of the node that
        belong to that class.
        """
        if not self.is_leaf():
            return
        classes, card = np.unique(self._y, return_counts=True)
        if len(classes) == 0:
            # no samples reached the node, there is nothing to predict
            self._belief = 0.
            self._class = None
            return
        if len(classes) > 1:
            max_card = card.max()
            # proportion of the majority class over all the samples, with
            # two classes this is the same as max / (max + min)
            self._belief = max_card / card.sum()
            self._class = classes[card == max_card][0]
        else:
            self._belief = 1
            self._class = classes[0]

    def __str__(self) -> str:
        if self.is_leaf():
            return f"{self._title} - Leaf class={self._class} belief={self._belief:.6f} counts={np.unique(self._y, return_counts=True)}"
        else:
            return f"{self._title}"
|
||||||
|
|
||||||
|
|
||||||
|
class Siterator:
    """Stree preorder iterator

    Yields a node, then the nodes reached through get_down() and finally the
    nodes reached through get_up(), using an explicit stack of pending nodes.
    """

    def __init__(self, tree: 'Snode'):
        # pending nodes, the next one to visit is at the end of the list
        self._stack = []
        self._push(tree)

    def __iter__(self):
        return self

    def _push(self, node: 'Snode'):
        """Add node to the pending nodes, None (missing child) is ignored
        """
        if node is None:
            return
        self._stack.append(node)

    def __next__(self) -> 'Snode':
        if not self._stack:
            raise StopIteration()
        current = self._stack.pop()
        # up is stacked first so the down branch is visited before it
        self._push(current.get_up())
        self._push(current.get_down())
        return current
|
||||||
|
|
||||||
class Stree(BaseEstimator, ClassifierMixin):
|
class Stree(BaseEstimator, ClassifierMixin):
|
||||||
"""
|
"""
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, C: float=1.0, max_iter: int = 1000, random_state: int = 0, use_predictions: bool = False):
|
def __init__(self, C: float = 1.0, max_iter: int = 1000, random_state: int = 0, use_predictions: bool = False):
|
||||||
self._max_iter = max_iter
|
self._max_iter = max_iter
|
||||||
self._C = C
|
self._C = C
|
||||||
self._random_state = random_state
|
self._random_state = random_state
|
||||||
@@ -77,12 +157,14 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
def _build_predictor(self):
|
def _build_predictor(self):
|
||||||
"""Process the leaves to make them predictors
|
"""Process the leaves to make them predictors
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def run_tree(node: Snode):
|
def run_tree(node: Snode):
|
||||||
if node.is_leaf():
|
if node.is_leaf():
|
||||||
node.make_predictor()
|
node.make_predictor()
|
||||||
return
|
return
|
||||||
run_tree(node.get_down())
|
run_tree(node.get_down())
|
||||||
run_tree(node.get_up())
|
run_tree(node.get_up())
|
||||||
|
|
||||||
run_tree(self._tree)
|
run_tree(self._tree)
|
||||||
|
|
||||||
def train(self, X: np.ndarray, y: np.ndarray, title: str = 'root') -> Snode:
|
def train(self, X: np.ndarray, y: np.ndarray, title: str = 'root') -> Snode:
|
||||||
@@ -121,6 +203,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
k, l = predict_class(d, i_d, node.get_down())
|
k, l = predict_class(d, i_d, node.get_down())
|
||||||
m, n = predict_class(u, i_u, node.get_up())
|
m, n = predict_class(u, i_u, node.get_up())
|
||||||
return np.append(k, m), np.append(l, n)
|
return np.append(k, m), np.append(l, n)
|
||||||
|
|
||||||
# sklearn check
|
# sklearn check
|
||||||
check_is_fitted(self)
|
check_is_fitted(self)
|
||||||
# Input validation
|
# Input validation
|
||||||
@@ -136,6 +219,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
:param X: dataset
|
:param X: dataset
|
||||||
:type X: np.array
|
:type X: np.array
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def predict_class(xp: np.array, indices: np.array, dist: np.array, node: Snode) -> np.array:
|
def predict_class(xp: np.array, indices: np.array, dist: np.array, node: Snode) -> np.array:
|
||||||
"""Run the tree to compute predictions
|
"""Run the tree to compute predictions
|
||||||
|
|
||||||
@@ -161,6 +245,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
k, l = predict_class(d, i_d, r_d, node.get_down())
|
k, l = predict_class(d, i_d, r_d, node.get_down())
|
||||||
m, n = predict_class(u, i_u, r_u, node.get_up())
|
m, n = predict_class(u, i_u, r_u, node.get_up())
|
||||||
return np.append(k, m), np.append(l, n)
|
return np.append(k, m), np.append(l, n)
|
||||||
|
|
||||||
# sklearn check
|
# sklearn check
|
||||||
check_is_fitted(self)
|
check_is_fitted(self)
|
||||||
# Input validation
|
# Input validation
|
||||||
@@ -217,5 +302,10 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
def save_sub_datasets(self):
|
def save_sub_datasets(self):
|
||||||
"""Save the every dataset stored in the tree to check with manual classifier
|
"""Save the every dataset stored in the tree to check with manual classifier
|
||||||
"""
|
"""
|
||||||
|
if not os.path.isdir(self.__folder):
|
||||||
|
os.mkdir(self.__folder)
|
||||||
with open(self.get_catalog_name(), 'w', encoding='utf-8') as catalog:
|
with open(self.get_catalog_name(), 'w', encoding='utf-8') as catalog:
|
||||||
self._save_datasets(self._tree, catalog, 1)
|
self._save_datasets(self._tree, catalog, 1)
|
||||||
|
|
||||||
|
|
||||||
|
|
@@ -6,12 +6,14 @@ __version__ = "0.9"
|
|||||||
Plot 3D views of nodes in Stree
|
Plot 3D views of nodes in Stree
|
||||||
'''
|
'''
|
||||||
|
|
||||||
import numpy as np
|
import os
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
from mpl_toolkits.mplot3d import Axes3D
|
|
||||||
from trees.Snode import Snode
|
|
||||||
from trees.Stree import Stree
|
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.decomposition import PCA
|
||||||
|
from mpl_toolkits.mplot3d import Axes3D
|
||||||
|
|
||||||
|
from .Strees import Stree, Snode, Siterator
|
||||||
|
|
||||||
class Snode_graph(Snode):
|
class Snode_graph(Snode):
|
||||||
|
|
||||||
@@ -68,6 +70,7 @@ class Snode_graph(Snode):
|
|||||||
# get the splitting hyperplane
|
# get the splitting hyperplane
|
||||||
def hyperplane(x, y): return (-self._interceptor - self._vector[0][0] * x
|
def hyperplane(x, y): return (-self._interceptor - self._vector[0][0] * x
|
||||||
- self._vector[0][1] * y) / self._vector[0][2]
|
- self._vector[0][1] * y) / self._vector[0][2]
|
||||||
|
|
||||||
tmpx = np.linspace(self._X[:, 0].min(), self._X[:, 0].max())
|
tmpx = np.linspace(self._X[:, 0].min(), self._X[:, 0].max())
|
||||||
tmpy = np.linspace(self._X[:, 1].min(), self._X[:, 1].max())
|
tmpy = np.linspace(self._X[:, 1].min(), self._X[:, 1].max())
|
||||||
xx, yy = np.meshgrid(tmpx, tmpy)
|
xx, yy = np.meshgrid(tmpx, tmpy)
|
||||||
@@ -93,3 +96,87 @@ class Snode_graph(Snode):
|
|||||||
ax.set_ylabel('X1')
|
ax.set_ylabel('X1')
|
||||||
ax.set_zlabel('X2')
|
ax.set_zlabel('X2')
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
class Stree_grapher(Stree):
    """Build 3d graphs of any dataset. If the dataset does not have exactly
    3 features, PCA is used to project it onto 3 components before fitting.

    Iterating over a fitted grapher yields the nodes of the mirrored tree
    (Snode_graph objects) in preorder.
    """

    def __init__(self, params: dict):
        """
        :param params: keyword arguments passed to the Stree constructor
        :type params: dict
        """
        self._plot_size = (8, 8)
        self._tree_gr = None
        # make Snode store X's
        os.environ['TESTING'] = '1'
        self._fitted = False
        self._pca = None
        super().__init__(**params)

    def __del__(self):
        # pop with a default does not raise if the variable is already unset
        os.environ.pop('TESTING', None)
        plt.close('all')

    def _copy_tree(self, node: Snode) -> Snode_graph:
        """Recursively mirror the tree that hangs from node in a tree of
        Snode_graph nodes, keeping the class and belief of every node
        """
        mirror = Snode_graph(node)
        # clone node
        mirror._class = node._class
        mirror._belief = node._belief
        if node.get_down() is not None:
            mirror.set_down(self._copy_tree(node.get_down()))
        if node.get_up() is not None:
            mirror.set_up(self._copy_tree(node.get_up()))
        return mirror

    def fit(self, X: np.array, y: np.array) -> Stree:
        """Fit the Stree and copy the tree in a Snode_graph tree

        :param X: Dataset
        :type X: np.array
        :param y: Labels
        :type y: np.array
        :return: Stree model
        :rtype: Stree
        """
        if X.shape[1] != 3:
            self._pca = PCA(n_components=3)
            X = self._pca.fit_transform(X)
        else:
            # forget the projection of a previous fit, otherwise score would
            # apply it to data this model was not trained with
            self._pca = None
        res = super().fit(X, y)
        self._tree_gr = self._copy_tree(self._tree)
        self._fitted = True
        return res

    def score(self, X: np.array, y: np.array) -> float:
        """Score the model, projecting X with the PCA computed in fit if
        there is one and X does not have 3 features

        :raises Exception: if the grapher is not fitted
        """
        self._check_fitted()
        if X.shape[1] != 3 and self._pca is not None:
            X = self._pca.transform(X)
        return super().score(X, y)

    def _check_fitted(self):
        """
        :raises Exception: if fit has not been called
        """
        if not self._fitted:
            raise Exception('Have to fit the grapher first!')

    def save_all(self, save_folder: str = './', save_prefix: str = ''):
        """Save all the node plots in png format, each with a sequence number

        :param save_folder: folder where the plots are saved, defaults to './'
        :type save_folder: str, optional
        :param save_prefix: passed to save_hyperplane of every node,
            defaults to ''
        :type save_prefix: str, optional
        :raises Exception: if the grapher is not fitted
        """
        self._check_fitted()
        # sequence numbers start at 1 and follow the preorder of the tree
        for seq, node in enumerate(self, start=1):
            node.save_hyperplane(save_folder=save_folder,
                                 save_prefix=save_prefix, save_seq=seq)

    def plot_all(self):
        """Plots all the nodes

        :raises Exception: if the grapher is not fitted
        """
        self._check_fitted()
        for node in self:
            node.plot_hyperplane()

    def __iter__(self):
        return Siterator(self._tree_gr)
|
||||||
|
|
4
stree/__init__.py
Normal file
4
stree/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
__version__ = "0.9rc1"
|
||||||
|
__author__ = "Ricardo Montañana Gómez"
|
||||||
|
from .Strees import Stree, Snode, Siterator
|
||||||
|
from .Strees_grapher import Stree_grapher, Snode_graph
|
@@ -5,7 +5,7 @@ import unittest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn.datasets import make_classification
|
from sklearn.datasets import make_classification
|
||||||
|
|
||||||
from trees.Stree import Stree, Snode
|
from stree import Stree, Snode
|
||||||
|
|
||||||
|
|
||||||
class Stree_test(unittest.TestCase):
|
class Stree_test(unittest.TestCase):
|
||||||
@@ -14,7 +14,7 @@ class Stree_test(unittest.TestCase):
|
|||||||
os.environ['TESTING'] = '1'
|
os.environ['TESTING'] = '1'
|
||||||
self._random_state = 1
|
self._random_state = 1
|
||||||
self._clf = Stree(random_state=self._random_state,
|
self._clf = Stree(random_state=self._random_state,
|
||||||
use_predictions=False)
|
use_predictions=False)
|
||||||
self._clf.fit(*self._get_Xy())
|
self._clf.fit(*self._get_Xy())
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
@@ -163,10 +163,10 @@ class Stree_test(unittest.TestCase):
|
|||||||
yp = self._clf.predict_proba(X[:num, :])
|
yp = self._clf.predict_proba(X[:num, :])
|
||||||
self.assertListEqual(y[:num].tolist(), yp[:, 0].tolist())
|
self.assertListEqual(y[:num].tolist(), yp[:, 0].tolist())
|
||||||
expected_proba = [0.88395641, 0.36746962, 0.84158767, 0.34106833, 0.14269291, 0.85193236,
|
expected_proba = [0.88395641, 0.36746962, 0.84158767, 0.34106833, 0.14269291, 0.85193236,
|
||||||
0.29876058, 0.7282164, 0.85958616, 0.89517877, 0.99745224, 0.18860349,
|
0.29876058, 0.7282164, 0.85958616, 0.89517877, 0.99745224, 0.18860349,
|
||||||
0.30756427, 0.8318412, 0.18981198, 0.15564624, 0.25740655, 0.22923355,
|
0.30756427, 0.8318412, 0.18981198, 0.15564624, 0.25740655, 0.22923355,
|
||||||
0.87365959, 0.49928689, 0.95574351, 0.28761257, 0.28906333, 0.32643692,
|
0.87365959, 0.49928689, 0.95574351, 0.28761257, 0.28906333, 0.32643692,
|
||||||
0.29788483, 0.01657364, 0.81149083]
|
0.29788483, 0.01657364, 0.81149083]
|
||||||
expected = np.round(expected_proba, decimals=decimals).tolist()
|
expected = np.round(expected_proba, decimals=decimals).tolist()
|
||||||
computed = np.round(yp[:, 1], decimals=decimals).tolist()
|
computed = np.round(yp[:, 1], decimals=decimals).tolist()
|
||||||
for i in range(len(expected)):
|
for i in range(len(expected)):
|
||||||
@@ -178,9 +178,9 @@ class Stree_test(unittest.TestCase):
|
|||||||
coefficients to compute both predictions and splitted data
|
coefficients to compute both predictions and splitted data
|
||||||
"""
|
"""
|
||||||
model_clf = Stree(random_state=self._random_state,
|
model_clf = Stree(random_state=self._random_state,
|
||||||
use_predictions=True)
|
use_predictions=True)
|
||||||
model_computed = Stree(random_state=self._random_state,
|
model_computed = Stree(random_state=self._random_state,
|
||||||
use_predictions=False)
|
use_predictions=False)
|
||||||
X, y = self._get_Xy()
|
X, y = self._get_Xy()
|
||||||
model_clf.fit(X, y)
|
model_clf.fit(X, y)
|
||||||
model_computed.fit(X, y)
|
model_computed.fit(X, y)
|
||||||
@@ -201,7 +201,7 @@ class Stree_test(unittest.TestCase):
|
|||||||
b = use_math.score(X, y)
|
b = use_math.score(X, y)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
use_clf.score(X, y),
|
use_clf.score(X, y),
|
||||||
b
|
b
|
||||||
)
|
)
|
||||||
self.assertGreater(b, .95)
|
self.assertGreater(b, .95)
|
||||||
|
|
||||||
@@ -243,7 +243,71 @@ class Stree_test(unittest.TestCase):
|
|||||||
computed.append(str(node))
|
computed.append(str(node))
|
||||||
self.assertListEqual(expected, computed)
|
self.assertListEqual(expected, computed)
|
||||||
|
|
||||||
|
class Snode_test(unittest.TestCase):
    """Tests of the attributes that the nodes of a fitted Stree must have
    """

    def __init__(self, *args, **kwargs):
        # Snode only stores the samples if TESTING is set
        os.environ['TESTING'] = '1'
        self._random_state = 1
        self._clf = Stree(random_state=self._random_state,
                          use_predictions=True)
        self._clf.fit(*self._get_Xy())
        super().__init__(*args, **kwargs)

    @classmethod
    def tearDownClass(cls):
        # pop with a default does not raise if the variable is already unset
        os.environ.pop('TESTING', None)

    def _get_Xy(self):
        """Build a balanced two class dataset with 1500 samples and 3
        features
        """
        X, y = make_classification(n_samples=1500, n_features=3, n_informative=3,
                                   n_redundant=0, n_repeated=0, n_classes=2, n_clusters_per_class=2,
                                   class_sep=1.5, flip_y=0, weights=[0.5, 0.5], random_state=self._random_state)
        return X, y

    def test_attributes_in_leaves(self):
        """Check if the attributes in leaves have correct values so they form a predictor
        """

        def check_leave(node: Snode):
            if not node.is_leaf():
                check_leave(node.get_down())
                check_leave(node.get_up())
                return
            # Check Belief in leave
            classes, card = np.unique(node._y, return_counts=True)
            max_card = max(card)
            min_card = min(card)
            if len(classes) > 1:
                # the dataset has two classes, so max + min is the number of
                # samples of the node
                belief = max_card / (max_card + min_card)
            else:
                belief = 1
            self.assertEqual(belief, node._belief)
            # Check Class, take the first class with the highest count so the
            # comparison is made with a scalar even if there is a tie
            class_computed = classes[card == max_card][0]
            self.assertEqual(class_computed, node._class)

        check_leave(self._clf._tree)

    def test_nodes_coefs(self):
        """Check if the nodes of the tree have the right attributes filled
        """

        def run_tree(node: Snode):
            if node._belief < 1:
                # only exclude pure leaves
                self.assertIsNotNone(node._clf)
                self.assertIsNotNone(node._clf.coef_)
                self.assertIsNotNone(node._vector)
                self.assertIsNotNone(node._interceptor)
            if node.is_leaf():
                return
            run_tree(node.get_down())
            run_tree(node.get_up())

        run_tree(self._clf._tree)
|
||||||
|
|
1
stree/tests/__init__.py
Normal file
1
stree/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .Strees_test import Stree_test, Snode_test
|
42
test.ipynb
42
test.ipynb
File diff suppressed because one or more lines are too long
15
test2.ipynb
15
test2.ipynb
@@ -26,7 +26,8 @@
|
|||||||
"from sklearn.svm import LinearSVC\n",
|
"from sklearn.svm import LinearSVC\n",
|
||||||
"from sklearn.tree import DecisionTreeClassifier\n",
|
"from sklearn.tree import DecisionTreeClassifier\n",
|
||||||
"from sklearn.datasets import make_classification, load_iris, load_wine\n",
|
"from sklearn.datasets import make_classification, load_iris, load_wine\n",
|
||||||
"from trees.Stree import Stree\n",
|
"from sklearn.model_selection import train_test_split\n",
|
||||||
|
"from stree import Stree\n",
|
||||||
"import time"
|
"import time"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@@ -50,14 +51,10 @@
|
|||||||
{
|
{
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"text": "Fraud: 0.173% 492\nValid: 99.827% 284315\nX.shape (1492, 28) y.shape (1492,)\nFraud: 32.976% 492\nValid: 67.024% 1000\n"
|
"text": "Fraud: 0.173% 492\nValid: 99.827% 284315\nX.shape (1492, 28) y.shape (1492,)\nFraud: 33.110% 494\nValid: 66.890% 998\n"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"import time\n",
|
|
||||||
"from sklearn.model_selection import train_test_split\n",
|
|
||||||
"from trees.Stree import Stree\n",
|
|
||||||
"\n",
|
|
||||||
"random_state=1\n",
|
"random_state=1\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def load_creditcard(n_examples=0):\n",
|
"def load_creditcard(n_examples=0):\n",
|
||||||
@@ -105,7 +102,7 @@
|
|||||||
{
|
{
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"text": "************** C=0.001 ****************************\nClassifier's accuracy (train): 0.9559\nClassifier's accuracy (test) : 0.9442\nroot\nroot - Down, <cgaf> - Leaf class=1 belief=0.986928 counts=(array([0, 1]), array([ 4, 302]))\nroot - Up, <cgaf> - Leaf class=0 belief=0.943089 counts=(array([0, 1]), array([696, 42]))\n\n**************************************************\n************** C=0.01 ****************************\nClassifier's accuracy (train): 0.9588\nClassifier's accuracy (test) : 0.9397\nroot\nroot - Down, <cgaf> - Leaf class=1 belief=0.993443 counts=(array([0, 1]), array([ 2, 303]))\nroot - Up, <cgaf> - Leaf class=0 belief=0.944520 counts=(array([0, 1]), array([698, 41]))\n\n**************************************************\n************** C=1 ****************************\nClassifier's accuracy (train): 0.9703\nClassifier's accuracy (test) : 0.9531\nroot\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([313]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([4]))\nroot - Up, <cgaf> - Leaf class=0 belief=0.957359 counts=(array([0, 1]), array([696, 31]))\n\n**************************************************\n************** C=5 ****************************\nClassifier's accuracy (train): 0.9703\nClassifier's accuracy (test) : 0.9531\nroot\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([313]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([7]))\nroot - Up, <cgaf> - Leaf class=0 belief=0.957182 counts=(array([0, 1]), array([693, 31]))\n\n**************************************************\n************** C=17 ****************************\nClassifier's accuracy (train): 0.9789\nClassifier's accuracy (test) : 0.9509\nroot\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([313]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 
counts=(array([0]), array([10]))\nroot - Up\nroot - Up - Down, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([4]))\nroot - Up - Up\nroot - Up - Up - Down, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([4]))\nroot - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([5]))\nroot - Up - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nroot - Up - Up - Up - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up - Up - Up - Up - Up - Up, <cgaf> - Leaf class=0 belief=0.968481 counts=(array([0, 1]), array([676, 22]))\n\n**************************************************\n0.6609 secs\n"
|
"text": "************** C=0.001 ****************************\nClassifier's accuracy (train): 0.9521\nClassifier's accuracy (test) : 0.9598\nroot\nroot - Down, <cgaf> - Leaf class=1 belief=0.980519 counts=(array([0, 1]), array([ 6, 302]))\nroot - Up, <cgaf> - Leaf class=0 belief=0.940217 counts=(array([0, 1]), array([692, 44]))\n\n**************************************************\n************** C=0.01 ****************************\nClassifier's accuracy (train): 0.9521\nClassifier's accuracy (test) : 0.9643\nroot\nroot - Down\nroot - Down - Down, <cgaf> - Leaf class=1 belief=0.986842 counts=(array([0, 1]), array([ 4, 300]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up, <cgaf> - Leaf class=0 belief=0.937754 counts=(array([0, 1]), array([693, 46]))\n\n**************************************************\n************** C=1 ****************************\nClassifier's accuracy (train): 0.9636\nClassifier's accuracy (test) : 0.9688\nroot\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([308]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([8]))\nroot - Up, <cgaf> - Leaf class=0 belief=0.947802 counts=(array([0, 1]), array([690, 38]))\n\n**************************************************\n************** C=5 ****************************\nClassifier's accuracy (train): 0.9665\nClassifier's accuracy (test) : 0.9621\nroot\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([308]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([11]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up\nroot - Up - Up - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), 
array([1]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Up, <cgaf> - Leaf class=0 belief=0.951456 counts=(array([0, 1]), array([686, 35]))\n\n**************************************************\n************** C=17 ****************************\nClassifier's accuracy (train): 0.9741\nClassifier's accuracy (test) : 0.9576\nroot\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([306]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([10]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nroot - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([3]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down\nroot - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([7]))\nroot - Up - Up - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([4]))\nroot - Up - Up - Up - Up - Up - Up, <cgaf> - Leaf class=0 belief=0.961538 
counts=(array([0, 1]), array([675, 27]))\n\n**************************************************\n0.7816 secs\n"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
@@ -146,7 +143,7 @@
|
|||||||
{
|
{
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"text": "root\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([313]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([10]))\nroot - Up\nroot - Up - Down, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([4]))\nroot - Up - Up\nroot - Up - Up - Down, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([4]))\nroot - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([5]))\nroot - Up - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nroot - Up - Up - Up - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up - Up - Up - Up - Up - Up, <cgaf> - Leaf class=0 belief=0.968481 counts=(array([0, 1]), array([676, 22]))\n"
|
"text": "root\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([306]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([10]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nroot - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([3]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down\nroot - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([7]))\nroot - Up - Up - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([4]))\nroot - Up - Up - Up - Up - Up - Up, <cgaf> - Leaf class=0 belief=0.961538 counts=(array([0, 1]), array([675, 27]))\n"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
@@ -163,7 +160,7 @@
|
|||||||
{
|
{
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"text": "root\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([313]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([10]))\nroot - Up\nroot - Up - Down, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([4]))\nroot - Up - Up\nroot - Up - Up - Down, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([4]))\nroot - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([5]))\nroot - Up - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nroot - Up - Up - Up - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up - Up - Up - Up - Up - Up, <cgaf> - Leaf class=0 belief=0.968481 counts=(array([0, 1]), array([676, 22]))\n"
|
"text": "root\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([306]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([10]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nroot - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([3]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down\nroot - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([7]))\nroot - Up - Up - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([4]))\nroot - Up - Up - Up - Up - Up - Up, <cgaf> - Leaf class=0 belief=0.961538 counts=(array([0, 1]), array([675, 27]))\n"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
|
File diff suppressed because one or more lines are too long
@@ -1,72 +0,0 @@
|
|||||||
import os
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from sklearn.datasets import make_classification
|
|
||||||
|
|
||||||
from trees.Stree import Stree, Snode
|
|
||||||
|
|
||||||
|
|
||||||
class Snode_test(unittest.TestCase):
    """Checks on the nodes of a fitted Stree."""

    def __init__(self, *args, **kwargs):
        # TESTING must be set before fitting so every Snode keeps its X
        os.environ['TESTING'] = '1'
        self._random_state = 1
        self._clf = Stree(random_state=self._random_state,
                          use_predictions=True)
        self._clf.fit(*self._get_Xy())
        super().__init__(*args, **kwargs)

    @classmethod
    def tearDownClass(cls):
        # pop with a default: no error if the variable is already gone and
        # no bare except hiding unrelated failures
        os.environ.pop('TESTING', None)

    def _get_Xy(self):
        """Build the synthetic two class dataset used to fit the tree.

        :return: samples and labels
        :rtype: tuple
        """
        X, y = make_classification(
            n_samples=1500, n_features=3, n_informative=3, n_redundant=0,
            n_repeated=0, n_classes=2, n_clusters_per_class=2,
            class_sep=1.5, flip_y=0, weights=[0.5, 0.5],
            random_state=self._random_state)
        return X, y

    def test_attributes_in_leaves(self):
        """Check if the attributes in leaves have correct values so they
        form a predictor
        """
        def check_leave(node: Snode):
            if not node.is_leaf():
                check_leave(node.get_down())
                check_leave(node.get_up())
                return
            # Check Belief in leave
            classes, card = np.unique(node._y, return_counts=True)
            max_card = max(card)
            min_card = min(card)
            if len(classes) > 1:
                # counts are >= 1 so this division cannot fail
                belief = max_card / (max_card + min_card)
            else:
                belief = 1
            self.assertEqual(belief, node._belief)
            # Check Class: take the first majority class, as the node does,
            # so a tie does not end up comparing an array with a scalar
            class_computed = classes[card == max_card][0]
            self.assertEqual(class_computed, node._class)

        check_leave(self._clf._tree)

    def test_nodes_coefs(self):
        """Check if the nodes of the tree have the right attributes filled
        """
        def run_tree(node: Snode):
            if node._belief < 1:
                # only exclude pure leaves
                self.assertIsNotNone(node._clf)
                self.assertIsNotNone(node._clf.coef_)
                self.assertIsNotNone(node._vector)
                self.assertIsNotNone(node._interceptor)
            if node.is_leaf():
                return
            run_tree(node.get_down())
            run_tree(node.get_up())

        run_tree(self._clf._tree)
|
|
@@ -1,33 +0,0 @@
|
|||||||
'''
|
|
||||||
__author__ = "Ricardo Montañana Gómez"
|
|
||||||
__copyright__ = "Copyright 2020, Ricardo Montañana Gómez"
|
|
||||||
__license__ = "MIT"
|
|
||||||
__version__ = "0.9"
|
|
||||||
Preorder iterator for the binary tree of Snodes
|
|
||||||
'''
|
|
||||||
|
|
||||||
from trees.Snode import Snode
|
|
||||||
|
|
||||||
|
|
||||||
class Siterator:
    """Stree preorder iterator

    Every node is returned before its children: first the node, then the
    whole "down" subtree and finally the whole "up" subtree. The traversal
    keeps the pending nodes in an explicit stack instead of recursing.
    """

    def __init__(self, tree: 'Snode'):
        """Start the traversal at the given node

        :param tree: root of the tree, None produces an empty iterator
        :type tree: Snode
        """
        self._stack = []
        self._push(tree)

    def __iter__(self):
        return self

    def _push(self, node: 'Snode'):
        """Add a node to the pending stack, missing children are ignored

        :param node: node to visit later, may be None
        :type node: Snode
        """
        if node is not None:
            self._stack.append(node)

    def __next__(self) -> 'Snode':
        """Return the next node in preorder

        :raises StopIteration: when every node has been returned
        :return: next node of the traversal
        :rtype: Snode
        """
        if not self._stack:
            raise StopIteration()
        node = self._stack.pop()
        # "up" goes in first so that "down" is the next one popped
        self._push(node.get_up())
        self._push(node.get_down())
        return node
|
|
@@ -1,70 +0,0 @@
|
|||||||
'''
|
|
||||||
__author__ = "Ricardo Montañana Gómez"
|
|
||||||
__copyright__ = "Copyright 2020, Ricardo Montañana Gómez"
|
|
||||||
__license__ = "MIT"
|
|
||||||
__version__ = "0.9"
|
|
||||||
Node of the Stree (binary tree)
|
|
||||||
'''
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from sklearn.svm import LinearSVC
|
|
||||||
|
|
||||||
class Snode:
    """Node of the Stree binary tree

    Holds the classifier that splits the node, the labels of the samples
    that reached it and links to the "down" and "up" children. A node with
    no children is a leaf; make_predictor fills its class and belief.
    """

    def __init__(self, clf: 'LinearSVC', X: np.ndarray, y: np.ndarray,
                 title: str):
        """Create a node with no children

        :param clf: classifier of the node, None if there is none
        :type clf: LinearSVC
        :param X: samples of the node, only kept if the TESTING environment
            variable is set
        :type X: np.ndarray
        :param y: labels of the samples of the node
        :type y: np.ndarray
        :param title: text used to describe the node in __str__
        :type title: str
        """
        self._clf = clf
        self._vector = None if clf is None else clf.coef_
        self._interceptor = 0. if clf is None else clf.intercept_
        self._title = title
        # belief of the prediction in a leaf node based on samples
        self._belief = 0.
        # Only store dataset in Testing
        self._X = X if os.environ.get('TESTING', 'NS') != 'NS' else None
        self._y = y
        self._down = None
        self._up = None
        self._class = None

    @classmethod
    def copy(cls, node: 'Snode') -> 'Snode':
        """Build a new node from the classifier, data and title of another

        Children, class and belief are not copied.

        :param node: node to take the values from
        :type node: Snode
        :return: new node without children
        :rtype: Snode
        """
        return cls(node._clf, node._X, node._y, node._title)

    def set_down(self, son):
        """Set the "down" child of the node"""
        self._down = son

    def set_up(self, son):
        """Set the "up" child of the node"""
        self._up = son

    def is_leaf(self) -> bool:
        """Tell if the node has no children

        :return: True if both children are None
        :rtype: bool
        """
        return self._up is None and self._down is None

    def get_down(self) -> 'Snode':
        """Return the "down" child, None if there is none"""
        return self._down

    def get_up(self) -> 'Snode':
        """Return the "up" child, None if there is none"""
        return self._up

    def make_predictor(self):
        """Compute the class of the predictor and its belief based on the
        subdataset of the node only if it is a leaf
        """
        if not self.is_leaf():
            return
        classes, card = np.unique(self._y, return_counts=True)
        if len(classes) > 1:
            max_card = max(card)
            min_card = min(card)
            # every count returned by np.unique is >= 1, the denominator
            # cannot be zero so there is nothing to catch here
            self._belief = max_card / (max_card + min_card)
            # with a tie the first class with the highest count wins
            self._class = classes[card == max_card][0]
        else:
            self._belief = 1.0
            self._class = classes[0]

    def __str__(self) -> str:
        """Describe the node: the title and, in leaves, class, belief and
        label counts
        """
        if not self.is_leaf():
            return f"{self._title}"
        counts = np.unique(self._y, return_counts=True)
        return (f"{self._title} - Leaf class={self._class} "
                f"belief={self._belief:.6f} counts={counts}")
|
|
@@ -1,101 +0,0 @@
|
|||||||
'''
|
|
||||||
__author__ = "Ricardo Montañana Gómez"
|
|
||||||
__copyright__ = "Copyright 2020, Ricardo Montañana Gómez"
|
|
||||||
__license__ = "MIT"
|
|
||||||
__version__ = "0.9"
|
|
||||||
Plot 3D views of nodes in Stree
|
|
||||||
'''
|
|
||||||
|
|
||||||
import os
|
|
||||||
import numpy as np
|
|
||||||
from sklearn.decomposition import PCA
|
|
||||||
import trees
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
from trees.Snode import Snode
|
|
||||||
from trees.Snode_graph import Snode_graph
|
|
||||||
from trees.Stree import Stree
|
|
||||||
from trees.Siterator import Siterator
|
|
||||||
|
|
||||||
|
|
||||||
class Stree_grapher(Stree):
    """Build 3d graphs of any dataset, if it does not have exactly 3
    features PCA is used to project it to 3 components

    After fitting, the tree is mirrored into Snode_graph nodes and iterating
    the grapher goes through those nodes.
    """

    def __init__(self, params: dict):
        """Create the grapher

        Sets the TESTING environment variable for the whole process so the
        Snodes store their X; it is removed again in __del__.

        :param params: keyword arguments passed to the Stree constructor
        :type params: dict
        """
        self._plot_size = (8, 8)
        self._tree_gr = None
        # make Snode store X's
        os.environ['TESTING'] = '1'
        self._fitted = False
        self._pca = None
        super().__init__(**params)

    def __del__(self):
        # pop with a default: nothing is raised if the variable is already
        # gone and no bare except is needed
        os.environ.pop('TESTING', None)
        plt.close('all')

    def _copy_tree(self, node: Snode) -> Snode_graph:
        """Recursively mirror a tree of Snodes into Snode_graph nodes

        :param node: root of the (sub)tree to copy
        :type node: Snode
        :return: root of the mirrored (sub)tree
        :rtype: Snode_graph
        """
        mirror = Snode_graph(node)
        # clone node
        mirror._class = node._class
        mirror._belief = node._belief
        if node.get_down() is not None:
            mirror.set_down(self._copy_tree(node.get_down()))
        if node.get_up() is not None:
            mirror.set_up(self._copy_tree(node.get_up()))
        return mirror

    def fit(self, X: np.array, y: np.array) -> Stree:
        """Fit the Stree and copy the tree in a Snode_graph tree

        :param X: Dataset
        :type X: np.array
        :param y: Labels
        :type y: np.array
        :return: Stree model
        :rtype: Stree
        """
        if X.shape[1] != 3:
            # the plots are 3d, project the dataset to 3 components
            self._pca = PCA(n_components=3)
            X = self._pca.fit_transform(X)
        res = super().fit(X, y)
        self._tree_gr = self._copy_tree(self._tree)
        self._fitted = True
        return res

    def score(self, X: np.array, y: np.array) -> float:
        """Score the fitted model, projecting X with the PCA computed in fit
        if X does not have 3 features

        :param X: Dataset
        :type X: np.array
        :param y: Labels
        :type y: np.array
        :raises Exception: if the grapher has not been fitted
        :return: score computed by Stree
        :rtype: float
        """
        self._check_fitted()
        if X.shape[1] != 3:
            X = self._pca.transform(X)
        return super().score(X, y)

    def _check_fitted(self):
        """Raise an Exception if fit has not been called yet"""
        if not self._fitted:
            raise Exception('Have to fit the grapher first!')

    def save_all(self, save_folder: str = './', save_prefix: str = ''):
        """Save all the node plots in png format, each with a sequence number

        :param save_folder: folder where the plots are saved, defaults to './'
        :type save_folder: str, optional
        :param save_prefix: prefix handed to save_hyperplane of every node,
            defaults to ''
        :type save_prefix: str, optional
        """
        self._check_fitted()
        # sequence numbers start at 1
        for seq, node in enumerate(self, start=1):
            node.save_hyperplane(save_folder=save_folder,
                                 save_prefix=save_prefix, save_seq=seq)

    def plot_all(self):
        """Plots all the nodes
        """
        self._check_fitted()
        for node in self:
            node.plot_hyperplane()

    def __iter__(self):
        """Iterate the Snode_graph mirror of the tree"""
        return Siterator(self._tree_gr)
|
|
Reference in New Issue
Block a user