mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-17 08:26:00 +00:00
Implement predict and score methods & tests
This commit is contained in:
@@ -2,8 +2,8 @@
|
||||
__author__ = "Ricardo Montañana Gómez"
|
||||
__copyright__ = "Copyright 2020, Ricardo Montañana Gómez"
|
||||
__license__ = "MIT"
|
||||
__version__ = "1.0"
|
||||
Create a oblique tree classifier based on SVM Trees
|
||||
__version__ = "0.9"
|
||||
Build an oblique tree classifier based on SVM Trees
|
||||
Uses LinearSVC
|
||||
'''
|
||||
|
||||
@@ -25,6 +25,7 @@ class Stree:
|
||||
self._tree = None
|
||||
self.__folder = 'data/'
|
||||
self.__use_predictions = use_predictions
|
||||
self.__trained = False
|
||||
|
||||
def _split_data(self, clf: LinearSVC, X: np.ndarray, y: np.ndarray) -> list:
|
||||
if self.__use_predictions:
|
||||
@@ -46,10 +47,11 @@ class Stree:
|
||||
|
||||
def fit(self, X: np.ndarray, y: np.ndarray, title: str = 'root') -> 'Stree':
|
||||
self._tree = self.train(X, y, title)
|
||||
self._predictor()
|
||||
self._build_predictor()
|
||||
self.__trained = True
|
||||
return self
|
||||
|
||||
def _predictor(self):
|
||||
def _build_predictor(self):
|
||||
"""Process the leaves to make them predictors
|
||||
"""
|
||||
def run_tree(node: Snode):
|
||||
@@ -79,6 +81,28 @@ class Stree:
|
||||
str(np.unique(y_d, return_counts=True))))
|
||||
return tree
|
||||
|
||||
def predict(self, X: np.array) -> np.array:
|
||||
def predict_class(xp: np.array, tree: Snode) -> np.array:
|
||||
if tree.is_leaf():
|
||||
return tree._class
|
||||
coef = tree._vector[0, :].reshape(-1, xp.shape[1])
|
||||
if xp.dot(coef.T) + tree._interceptor[0] > 0:
|
||||
return predict_class(xp, tree.get_down())
|
||||
return predict_class(xp, tree.get_up())
|
||||
y = np.array([], dtype=int)
|
||||
for xp in X:
|
||||
y = np.append(y, predict_class(xp.reshape(-1, X.shape[1]), self._tree))
|
||||
return y
|
||||
|
||||
def score(self, X: np.array, y: np.array, print_out=True) -> float:
|
||||
self.fit(X, y)
|
||||
yp = self.predict(X)
|
||||
right = (yp == y).astype(int)
|
||||
accuracy = sum(right) / len(y)
|
||||
if print_out:
|
||||
print(f"Accuracy: {accuracy:.6f}")
|
||||
return accuracy
|
||||
|
||||
def __str__(self):
|
||||
def print_tree(tree: Snode) -> str:
|
||||
output = str(tree)
|
||||
|
Reference in New Issue
Block a user