mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-16 07:56:06 +00:00
Implement predict and score methods & tests
This commit is contained in:
1
main.py
1
main.py
@@ -9,3 +9,4 @@ model = Stree(random_state=random_state)
|
|||||||
model.fit(X, y)
|
model.fit(X, y)
|
||||||
print(model)
|
print(model)
|
||||||
model.save_sub_datasets()
|
model.save_sub_datasets()
|
||||||
|
# Sanity check: classify the first training sample and show it next to its true label.
sample = X[0, :].reshape(-1, X.shape[1])  # keep the 2-D (1, n_features) shape predict expects
print(f"Predicting [{y[0]}] we have {model.predict(sample)}")
|
||||||
|
206
test.ipynb
206
test.ipynb
@@ -9,6 +9,7 @@
|
|||||||
"import numpy as np \n",
|
"import numpy as np \n",
|
||||||
"from sklearn.svm import LinearSVC\n",
|
"from sklearn.svm import LinearSVC\n",
|
||||||
"from sklearn.datasets import make_classification\n",
|
"from sklearn.datasets import make_classification\n",
|
||||||
|
"from trees.Stree import Stree\n",
|
||||||
"\n",
|
"\n",
|
||||||
"random_state = 1\n",
|
"random_state = 1\n",
|
||||||
"X, y = make_classification(n_samples=1500, n_features=3, n_informative=3, \n",
|
"X, y = make_classification(n_samples=1500, n_features=3, n_informative=3, \n",
|
||||||
@@ -18,215 +19,20 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 2,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"text": "data/dataset1.csv - root\ndata/dataset2.csv - root - Down\ndata/dataset3.csv - root - Down - Down, classes=[0 1], items<0>=17, items<1>=691, <couldn't go any further> LEAF accuracy=0.98\ndata/dataset4.csv - root - Down - Up\ndata/dataset5.csv - root - Down - Up - Down, classes=[0 1], items<0>=1, items<1>=3, <couldn't go any further> LEAF accuracy=0.75\ndata/dataset6.csv - root - Down - Up - Up, class=[0], items=7, rest=0, <pure> LEAF accuracy=1.00\ndata/dataset3.csv - root - Up, classes=[0 1], items<0>=725, items<1>=56, <couldn't go any further> LEAF accuracy=0.93\n"
|
"text": "Accuracy: 0.950667\n"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"!cat data/catalog.txt"
|
"clf = Stree(random_state=random_state, use_predictions=False)\n",
|
||||||
|
"clf.fit(X, y)\n",
|
||||||
|
"accuracy = clf.score(X, y)"
|
||||||
]
|
]
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 4,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def readsub(name):\n",
|
|
||||||
" data = np.genfromtxt(name, delimiter=',')\n",
|
|
||||||
" data = np.array(data)\n",
|
|
||||||
" py = data[:, data.shape[1] - 1]\n",
|
|
||||||
" px = np.delete(data, data.shape[1] - 1, axis=1)\n",
|
|
||||||
" return px, py\n",
|
|
||||||
"def localiza(X, px):\n",
|
|
||||||
" enc = False\n",
|
|
||||||
" for i in range(X.shape[0]):\n",
|
|
||||||
" if all(X[i, :] == px):\n",
|
|
||||||
" enc = True\n",
|
|
||||||
" print(f\" i={i} - X[{i}, :]={X[i, :]} - px={px} - y[{i}]={y[i]}\")\n",
|
|
||||||
" print(\"Encontrado:\", enc)\n",
|
|
||||||
" "
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 5,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"px, py = readsub('data/dataset5.csv')\n",
|
|
||||||
"model = LinearSVC(random_state=1, max_iter=1000)\n",
|
|
||||||
"model.fit(px,py)\n",
|
|
||||||
"yp = model.predict(px)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 15,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"output_type": "stream",
|
|
||||||
"name": "stdout",
|
|
||||||
"text": "[1. 1. 1. 1.]\n[1. 1. 0. 1.]\n"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"print(yp)\n",
|
|
||||||
"print(py)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 16,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"output_type": "stream",
|
|
||||||
"name": "stdout",
|
|
||||||
"text": "i=1132 - X[1132, :]=[-0.41453617 -0.38206564 0.54849331] - px=[-0.41453617 -0.38206564 0.54849331] - y[1132]=0\nEncontrado: True\n"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"localiza(X, px[2, :])"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 1,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"output_type": "stream",
|
|
||||||
"name": "stdout",
|
|
||||||
"text": "[LinearSVC(C=1.0, class_weight=None, dual=True, fit_intercept=True,\n intercept_scaling=1, loss='squared_hinge', max_iter=1000,\n multi_class='ovr', penalty='l2', random_state=None, tol=0.0001,\n verbose=0), LinearSVC(C=1.0, class_weight=None, dual=True, fit_intercept=True,\n intercept_scaling=1, loss='squared_hinge', max_iter=1000,\n multi_class='ovr', penalty='l2', random_state=None, tol=0.0001,\n verbose=0), LinearSVC(C=1.0, class_weight=None, dual=True, fit_intercept=True,\n intercept_scaling=1, loss='squared_hinge', max_iter=1000,\n multi_class='ovr', penalty='l2', random_state=None, tol=0.0001,\n verbose=0), LinearSVC(C=1.0, class_weight=None, dual=True, fit_intercept=True,\n intercept_scaling=1, loss='squared_hinge', max_iter=1000,\n multi_class='ovr', penalty='l2', random_state=None, tol=0.0001,\n verbose=0), LinearSVC(C=1.0, class_weight=None, dual=True, fit_intercept=True,\n intercept_scaling=1, loss='squared_hinge', max_iter=1000,\n multi_class='ovr', penalty='l2', random_state=None, tol=0.0001,\n verbose=0)]\n"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"from sklearn.svm import LinearSVC\n",
|
|
||||||
"\n",
|
|
||||||
"data = []\n",
|
|
||||||
"for i in range(5):\n",
|
|
||||||
" model = LinearSVC()\n",
|
|
||||||
" data.append(model)\n",
|
|
||||||
"\n",
|
|
||||||
"print(data)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 3,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"output_type": "stream",
|
|
||||||
"name": "stdout",
|
|
||||||
"text": "4\n"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"output_type": "error",
|
|
||||||
"ename": "NameError",
|
|
||||||
"evalue": "name 'gato' is not defined",
|
|
||||||
"traceback": [
|
|
||||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
||||||
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
|
||||||
"\u001b[0;32m<ipython-input-3-04351d05a6f0>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpato\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgato\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
|
||||||
"\u001b[0;31mNameError\u001b[0m: name 'gato' is not defined"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"def pato(k):\n",
|
|
||||||
" def gato(m, u):\n",
|
|
||||||
" return m * u\n",
|
|
||||||
" return gato(k, k)\n",
|
|
||||||
"\n",
|
|
||||||
"print(pato(2))\n",
|
|
||||||
"print(gato(3,4))"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 5,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"output_type": "stream",
|
|
||||||
"name": "stdout",
|
|
||||||
"text": "7\n"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"try:\n",
|
|
||||||
" a= max(5,3)/min(0,1)\n",
|
|
||||||
"except:\n",
|
|
||||||
" a=7\n",
|
|
||||||
"print(a)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 6,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"output_type": "error",
|
|
||||||
"ename": "SyntaxError",
|
|
||||||
"evalue": "invalid syntax (<ipython-input-6-65e24c447a24>, line 1)",
|
|
||||||
"traceback": [
|
|
||||||
"\u001b[0;36m File \u001b[0;32m\"<ipython-input-6-65e24c447a24>\"\u001b[0;36m, line \u001b[0;32m1\u001b[0m\n\u001b[0;31m max([2 5])\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"max([2 5])"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 7,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"y=[1,2,4,5,5,5,5,3,3,3,2,]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 10,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"a,b = np.unique(y, return_counts=True)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 12,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"output_type": "execute_result",
|
|
||||||
"data": {
|
|
||||||
"text/plain": "11"
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"execution_count": 12
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"np.count_nonzero(y)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": []
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
@@ -11,9 +11,9 @@ class Snode_test(unittest.TestCase):
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self._random_state = 1
|
self._random_state = 1
|
||||||
self._model = Stree(random_state=self._random_state,
|
self._clf = Stree(random_state=self._random_state,
|
||||||
use_predictions=True)
|
use_predictions=True)
|
||||||
self._model.fit(*self._get_Xy())
|
self._clf.fit(*self._get_Xy())
|
||||||
super(Snode_test, self).__init__(*args, **kwargs)
|
super(Snode_test, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
def _get_Xy(self):
|
def _get_Xy(self):
|
||||||
@@ -42,4 +42,4 @@ class Snode_test(unittest.TestCase):
|
|||||||
return
|
return
|
||||||
check_leave(node.get_down())
|
check_leave(node.get_down())
|
||||||
check_leave(node.get_up())
|
check_leave(node.get_up())
|
||||||
check_leave(self._model._tree)
|
check_leave(self._clf._tree)
|
||||||
|
@@ -11,9 +11,9 @@ class Stree_test(unittest.TestCase):
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self._random_state = 1
|
self._random_state = 1
|
||||||
self._model = Stree(random_state=self._random_state,
|
self._clf = Stree(random_state=self._random_state,
|
||||||
use_predictions=True)
|
use_predictions=False)
|
||||||
self._model.fit(*self._get_Xy())
|
self._clf.fit(*self._get_Xy())
|
||||||
super(Stree_test, self).__init__(*args, **kwargs)
|
super(Stree_test, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
def _get_Xy(self):
|
def _get_Xy(self):
|
||||||
@@ -25,7 +25,7 @@ class Stree_test(unittest.TestCase):
|
|||||||
def _check_tree(self, node: Snode):
|
def _check_tree(self, node: Snode):
|
||||||
if node.is_leaf():
|
if node.is_leaf():
|
||||||
return
|
return
|
||||||
y_prediction = node._model.predict(node._X)
|
y_prediction = node._clf.predict(node._X)
|
||||||
y_down = node.get_down()._y
|
y_down = node.get_down()._y
|
||||||
y_up = node.get_up()._y
|
y_up = node.get_up()._y
|
||||||
# Is a correct partition in terms of cardinality?
|
# Is a correct partition in terms of cardinality?
|
||||||
@@ -55,7 +55,7 @@ class Stree_test(unittest.TestCase):
|
|||||||
def test_build_tree(self):
|
def test_build_tree(self):
|
||||||
"""Check if the tree is built the same way as predictions of models
|
"""Check if the tree is built the same way as predictions of models
|
||||||
"""
|
"""
|
||||||
self._check_tree(self._model._tree)
|
self._check_tree(self._clf._tree)
|
||||||
|
|
||||||
def _get_file_data(self, file_name: str) -> tuple:
|
def _get_file_data(self, file_name: str) -> tuple:
|
||||||
"""Return X, y from data, y is the last column in array
|
"""Return X, y from data, y is the last column in array
|
||||||
@@ -94,14 +94,32 @@ class Stree_test(unittest.TestCase):
|
|||||||
def test_subdatasets(self):
|
def test_subdatasets(self):
|
||||||
"""Check if the subdatasets files have the same predictions as the tree itself
|
"""Check if the subdatasets files have the same predictions as the tree itself
|
||||||
"""
|
"""
|
||||||
model = self._model._tree._model
|
model = self._clf._tree._clf
|
||||||
X, y = self._get_Xy()
|
X, y = self._get_Xy()
|
||||||
model.fit(X, y)
|
model.fit(X, y)
|
||||||
self._model.save_sub_datasets()
|
self._clf.save_sub_datasets()
|
||||||
with open(self._model.get_catalog_name()) as cat_file:
|
with open(self._clf.get_catalog_name()) as cat_file:
|
||||||
catalog = csv.reader(cat_file, delimiter=',')
|
catalog = csv.reader(cat_file, delimiter=',')
|
||||||
for row in catalog:
|
for row in catalog:
|
||||||
X, y = self._get_Xy()
|
X, y = self._get_Xy()
|
||||||
x_file, y_file = self._get_file_data(row[0])
|
x_file, y_file = self._get_file_data(row[0])
|
||||||
y_original = np.array(self._find_out(x_file, X, y), dtype=int)
|
y_original = np.array(self._find_out(x_file, X, y), dtype=int)
|
||||||
self.assertTrue(np.array_equal(y_file, y_original))
|
self.assertTrue(np.array_equal(y_file, y_original))
|
||||||
|
|
||||||
|
def test_single_prediction(self):
    """Predict the first sample alone and compare it with its known label
    """
    X, y = self._get_Xy()
    # predict() expects a 2-D array, so the single row is reshaped to (1, n_features)
    first_sample = X[0, :].reshape(-1, X.shape[1])
    predicted = self._clf.predict(first_sample)
    self.assertEqual(predicted[0], y[0])
|
||||||
|
|
||||||
|
def test_multiple_prediction(self):
    """Predict the first 23 samples at once and compare them with their labels
    """
    X, y = self._get_Xy()
    expected = y[:23].tolist()
    predicted = self._clf.predict(X[:23, :]).tolist()
    self.assertListEqual(expected, predicted)
|
||||||
|
|
||||||
|
def test_score(self):
    """Check that score() reports the same accuracy as one computed by hand
    from the output of predict()
    """
    X, y = self._get_Xy()
    accuracy_score = self._clf.score(X, y, print_out=False)
    # Recompute the accuracy independently: hits / number of samples
    hits = (self._clf.predict(X) == y).astype(int)
    accuracy_computed = sum(hits) / len(y)
    self.assertEqual(accuracy_score, accuracy_computed)
|
||||||
|
@@ -2,7 +2,7 @@
|
|||||||
__author__ = "Ricardo Montañana Gómez"
|
__author__ = "Ricardo Montañana Gómez"
|
||||||
__copyright__ = "Copyright 2020, Ricardo Montañana Gómez"
|
__copyright__ = "Copyright 2020, Ricardo Montañana Gómez"
|
||||||
__license__ = "MIT"
|
__license__ = "MIT"
|
||||||
__version__ = "1.0"
|
__version__ = "0.9"
|
||||||
Node of the Stree (binary tree)
|
Node of the Stree (binary tree)
|
||||||
'''
|
'''
|
||||||
|
|
||||||
@@ -11,10 +11,10 @@ from sklearn.svm import LinearSVC
|
|||||||
|
|
||||||
|
|
||||||
class Snode:
|
class Snode:
|
||||||
def __init__(self, model: LinearSVC, X: np.ndarray, y: np.ndarray, title: str):
|
def __init__(self, clf: LinearSVC, X: np.ndarray, y: np.ndarray, title: str):
|
||||||
self._model = model
|
self._clf = clf
|
||||||
self._vector = None if model is None else model.coef_
|
self._vector = None if clf is None else clf.coef_
|
||||||
self._interceptor = 0 if model is None else model.intercept_
|
self._interceptor = 0 if clf is None else clf.intercept_
|
||||||
self._title = title
|
self._title = title
|
||||||
self._belief = 0 # belief of the prediction in a leaf node based on samples
|
self._belief = 0 # belief of the prediction in a leaf node based on samples
|
||||||
self._X = X
|
self._X = X
|
||||||
@@ -60,6 +60,6 @@ class Snode:
|
|||||||
num = max(num, self._y[self._y == i].shape[0])
|
num = max(num, self._y[self._y == i].shape[0])
|
||||||
den = self._y.shape[0]
|
den = self._y.shape[0]
|
||||||
accuracy = num / den if den != 0 else 1
|
accuracy = num / den if den != 0 else 1
|
||||||
return f"{self._title} LEAF accuracy={accuracy:.2f}\n"
|
return f"{self._title} LEAF accuracy={accuracy:.2f}, belief={self._belief:.2f} class={self._class}\n"
|
||||||
else:
|
else:
|
||||||
return f"{self._title}\n"
|
return f"{self._title}\n"
|
||||||
|
@@ -2,8 +2,8 @@
|
|||||||
__author__ = "Ricardo Montañana Gómez"
|
__author__ = "Ricardo Montañana Gómez"
|
||||||
__copyright__ = "Copyright 2020, Ricardo Montañana Gómez"
|
__copyright__ = "Copyright 2020, Ricardo Montañana Gómez"
|
||||||
__license__ = "MIT"
|
__license__ = "MIT"
|
||||||
__version__ = "1.0"
|
__version__ = "0.9"
|
||||||
Create a oblique tree classifier based on SVM Trees
|
Build an oblique tree classifier based on SVM Trees
|
||||||
Uses LinearSVC
|
Uses LinearSVC
|
||||||
'''
|
'''
|
||||||
|
|
||||||
@@ -25,6 +25,7 @@ class Stree:
|
|||||||
self._tree = None
|
self._tree = None
|
||||||
self.__folder = 'data/'
|
self.__folder = 'data/'
|
||||||
self.__use_predictions = use_predictions
|
self.__use_predictions = use_predictions
|
||||||
|
self.__trained = False
|
||||||
|
|
||||||
def _split_data(self, clf: LinearSVC, X: np.ndarray, y: np.ndarray) -> list:
|
def _split_data(self, clf: LinearSVC, X: np.ndarray, y: np.ndarray) -> list:
|
||||||
if self.__use_predictions:
|
if self.__use_predictions:
|
||||||
@@ -46,10 +47,11 @@ class Stree:
|
|||||||
|
|
||||||
def fit(self, X: np.ndarray, y: np.ndarray, title: str = 'root') -> 'Stree':
|
def fit(self, X: np.ndarray, y: np.ndarray, title: str = 'root') -> 'Stree':
|
||||||
self._tree = self.train(X, y, title)
|
self._tree = self.train(X, y, title)
|
||||||
self._predictor()
|
self._build_predictor()
|
||||||
|
self.__trained = True
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _predictor(self):
|
def _build_predictor(self):
|
||||||
"""Process the leaves to make them predictors
|
"""Process the leaves to make them predictors
|
||||||
"""
|
"""
|
||||||
def run_tree(node: Snode):
|
def run_tree(node: Snode):
|
||||||
@@ -79,6 +81,28 @@ class Stree:
|
|||||||
str(np.unique(y_d, return_counts=True))))
|
str(np.unique(y_d, return_counts=True))))
|
||||||
return tree
|
return tree
|
||||||
|
|
||||||
|
def predict(self, X: np.array) -> np.array:
|
||||||
|
def predict_class(xp: np.array, tree: Snode) -> np.array:
|
||||||
|
if tree.is_leaf():
|
||||||
|
return tree._class
|
||||||
|
coef = tree._vector[0, :].reshape(-1, xp.shape[1])
|
||||||
|
if xp.dot(coef.T) + tree._interceptor[0] > 0:
|
||||||
|
return predict_class(xp, tree.get_down())
|
||||||
|
return predict_class(xp, tree.get_up())
|
||||||
|
y = np.array([], dtype=int)
|
||||||
|
for xp in X:
|
||||||
|
y = np.append(y, predict_class(xp.reshape(-1, X.shape[1]), self._tree))
|
||||||
|
return y
|
||||||
|
|
||||||
|
def score(self, X: np.array, y: np.array, print_out=True) -> float:
    """Return the accuracy of the classifier on the given samples

    The tree is fitted with (X, y) only when it has not been built yet
    (self._tree is still None). An already fitted tree is evaluated as it
    is, so scoring a held-out set does not retrain the model on that set.

    :param X: samples to classify, one sample per row
    :param y: expected class of every sample in X
    :param print_out: if True the accuracy is also printed to stdout
    :return: fraction of the samples whose predicted class equals y
    """
    if self._tree is None:
        self.fit(X, y)
    yp = self.predict(X)
    hits = np.asarray(yp) == np.asarray(y)
    accuracy = np.count_nonzero(hits) / len(y)
    if print_out:
        print(f"Accuracy: {accuracy:.6f}")
    return accuracy
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
def print_tree(tree: Snode) -> str:
|
def print_tree(tree: Snode) -> str:
|
||||||
output = str(tree)
|
output = str(tree)
|
||||||
|
Reference in New Issue
Block a user