mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-16 07:56:06 +00:00
First Approach
This commit is contained in:
@@ -7,9 +7,9 @@ Build an oblique tree classifier based on SVM Trees
|
||||
"""
|
||||
|
||||
import os
|
||||
import numbers
|
||||
import random
|
||||
import warnings
|
||||
from typing import Optional, List, Union, Tuple
|
||||
from math import log
|
||||
from itertools import combinations
|
||||
import numpy as np
|
||||
@@ -78,10 +78,10 @@ class Snode:
|
||||
def is_leaf(self) -> bool:
|
||||
return self._up is None and self._down is None
|
||||
|
||||
def get_down(self) -> "Snode":
|
||||
def get_down(self) -> Optional["Snode"]:
|
||||
return self._down
|
||||
|
||||
def get_up(self) -> "Snode":
|
||||
def get_up(self) -> Optional["Snode"]:
|
||||
return self._up
|
||||
|
||||
def make_predictor(self):
|
||||
@@ -123,11 +123,11 @@ class Siterator:
|
||||
"""Stree preorder iterator
|
||||
"""
|
||||
|
||||
def __init__(self, tree: Snode):
|
||||
self._stack = []
|
||||
def __init__(self, tree: Optional[Snode]):
|
||||
self._stack: List[Snode] = []
|
||||
self._push(tree)
|
||||
|
||||
def _push(self, node: Snode):
|
||||
def _push(self, node: Optional[Snode]) -> None:
|
||||
if node is not None:
|
||||
self._stack.append(node)
|
||||
|
||||
@@ -150,7 +150,7 @@ class Splitter:
|
||||
min_samples_split: int = None,
|
||||
random_state=None,
|
||||
):
|
||||
self._clf = clf
|
||||
self._clf: Union[SVC, LinearSVC] = clf
|
||||
self._random_state = random_state
|
||||
if random_state is not None:
|
||||
random.seed(random_state)
|
||||
@@ -230,8 +230,8 @@ class Splitter:
|
||||
def _select_best_set(
|
||||
self, dataset: np.array, labels: np.array, features_sets: list
|
||||
) -> list:
|
||||
max_gain = 0
|
||||
selected = None
|
||||
max_gain: float = 0.0
|
||||
selected: Union[List[int], None] = None
|
||||
warnings.filterwarnings("ignore", category=ConvergenceWarning)
|
||||
for feature_set in features_sets:
|
||||
self._clf.fit(dataset[:, feature_set], labels)
|
||||
@@ -265,7 +265,7 @@ class Splitter:
|
||||
|
||||
def get_subspace(
|
||||
self, dataset: np.array, labels: np.array, max_features: int
|
||||
) -> list:
|
||||
) -> Tuple[np.array, np.array]:
|
||||
"""Return the best subspace to make a split
|
||||
"""
|
||||
indices = self._get_subspaces_set(dataset, labels, max_features)
|
||||
@@ -478,7 +478,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
sample_weight: np.ndarray,
|
||||
depth: int,
|
||||
title: str,
|
||||
) -> Snode:
|
||||
) -> Optional[Snode]:
|
||||
"""Recursive function to split the original dataset into predictor
|
||||
nodes (leaves)
|
||||
|
||||
@@ -543,11 +543,13 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
node.set_down(self.train(X_D, y_d, sw_d, depth + 1, title + " - Down"))
|
||||
return node
|
||||
|
||||
def _build_predictor(self):
|
||||
def _build_predictor(self) -> None:
|
||||
"""Process the leaves to make them predictors
|
||||
"""
|
||||
|
||||
def run_tree(node: Snode):
|
||||
def run_tree(node: Optional[Snode]) -> None:
|
||||
if node is None:
|
||||
raise ValueError("Can't build predictors on None")
|
||||
if node.is_leaf():
|
||||
node.make_predictor()
|
||||
return
|
||||
@@ -556,7 +558,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
|
||||
run_tree(self.tree_)
|
||||
|
||||
def _build_clf(self):
|
||||
def _build_clf(self) -> Union[LinearSVC, SVC]:
|
||||
""" Build the correct classifier for the node
|
||||
"""
|
||||
return (
|
||||
@@ -605,7 +607,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
"""
|
||||
|
||||
def predict_class(
|
||||
xp: np.array, indices: np.array, node: Snode
|
||||
xp: np.array, indices: np.array, node: Optional[Snode]
|
||||
) -> np.array:
|
||||
if xp is None:
|
||||
return [], []
|
||||
@@ -704,7 +706,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
)
|
||||
elif self.max_features is None:
|
||||
max_features = self.n_features_
|
||||
elif isinstance(self.max_features, numbers.Integral):
|
||||
elif isinstance(self.max_features, int):
|
||||
max_features = self.max_features
|
||||
else: # float
|
||||
if self.max_features > 0.0:
|
||||
|
@@ -414,3 +414,24 @@ class Stree_test(unittest.TestCase):
|
||||
# zero weights are ok when they don't erase a class
|
||||
_ = clf.train(X, y, weights_no_zero, 1, "test")
|
||||
self.assertListEqual(weights_no_zero.tolist(), original.tolist())
|
||||
|
||||
def test_build_predictor(self):
|
||||
X, y = load_dataset(self._random_state)
|
||||
clf = Stree(random_state=self._random_state)
|
||||
with self.assertRaises(ValueError):
|
||||
clf.tree_ = None
|
||||
clf._build_predictor()
|
||||
clf.fit(X, y)
|
||||
node = clf.tree_.get_down().get_down()
|
||||
expected_impurity = 0.04686951386893923
|
||||
expected_class = 1
|
||||
expected_belief = 0.9759887005649718
|
||||
self.assertAlmostEqual(expected_impurity, node._impurity)
|
||||
self.assertAlmostEqual(expected_belief, node._belief)
|
||||
self.assertEqual(expected_class, node._class)
|
||||
node._belief = 0.0
|
||||
node._class = None
|
||||
clf._build_predictor()
|
||||
node = clf.tree_.get_down().get_down()
|
||||
self.assertAlmostEqual(expected_belief, node._belief)
|
||||
self.assertEqual(expected_class, node._class)
|
||||
|
Reference in New Issue
Block a user