First Approach

2020-06-28 02:46:20 +02:00
parent be552fdd6c
commit fa001f97a4
2 changed files with 39 additions and 16 deletions


@@ -7,9 +7,9 @@ Build an oblique tree classifier based on SVM Trees
"""
import os
import numbers
import random
import warnings
from typing import Optional, List, Union, Tuple
from math import log
from itertools import combinations
import numpy as np
@@ -78,10 +78,10 @@ class Snode:
def is_leaf(self) -> bool:
return self._up is None and self._down is None
def get_down(self) -> "Snode":
def get_down(self) -> Optional["Snode"]:
return self._down
def get_up(self) -> "Snode":
def get_up(self) -> Optional["Snode"]:
return self._up
def make_predictor(self):
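With the children now typed as Optional, code that walks the tree has to guard against None before recursing. A minimal caller-side sketch; the count_leaves helper is hypothetical and not part of this commit:

```python
from typing import Optional

def count_leaves(node: Optional["Snode"]) -> int:
    # Hypothetical helper: Optional children force an explicit None check.
    if node is None:
        return 0
    if node.is_leaf():
        return 1
    return count_leaves(node.get_down()) + count_leaves(node.get_up())
```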
@@ -123,11 +123,11 @@ class Siterator:
"""Stree preorder iterator
"""
def __init__(self, tree: Snode):
self._stack = []
def __init__(self, tree: Optional[Snode]):
self._stack: List[Snode] = []
self._push(tree)
def _push(self, node: Snode):
def _push(self, node: Optional[Snode]) -> None:
if node is not None:
self._stack.append(node)
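For context, a self-contained sketch of how a stack-based preorder iterator like this one typically advances. The __next__ body is not part of this hunk, so the class name and traversal details below are assumptions, not the project's code:

```python
from typing import List, Optional

class PreorderIterator:
    """Sketch of a stack-based preorder iterator over Snode-like nodes."""

    def __init__(self, tree: Optional["Snode"]) -> None:
        self._stack: List["Snode"] = []
        self._push(tree)

    def _push(self, node: Optional["Snode"]) -> None:
        if node is not None:
            self._stack.append(node)

    def __iter__(self) -> "PreorderIterator":
        return self

    def __next__(self) -> "Snode":
        if not self._stack:
            raise StopIteration()
        node = self._stack.pop()
        # Yield the node first, then stack its children for later visits.
        self._push(node.get_up())
        self._push(node.get_down())
        return node
```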
@@ -150,7 +150,7 @@ class Splitter:
min_samples_split: int = None,
random_state=None,
):
self._clf = clf
self._clf: Union[SVC, LinearSVC] = clf
self._random_state = random_state
if random_state is not None:
random.seed(random_state)
@@ -230,8 +230,8 @@ class Splitter:
def _select_best_set(
self, dataset: np.array, labels: np.array, features_sets: list
) -> list:
max_gain = 0
selected = None
max_gain: float = 0.0
selected: Union[List[int], None] = None
warnings.filterwarnings("ignore", category=ConvergenceWarning)
for feature_set in features_sets:
self._clf.fit(dataset[:, feature_set], labels)
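The hunk only shows the initialization and the fit call; a standalone sketch of the surrounding selection loop follows, with plain accuracy standing in for whatever gain criterion the class actually uses (the function name and the criterion are assumptions):

```python
from typing import List, Optional, Sequence

import numpy as np
from sklearn.svm import LinearSVC

def select_best_set(
    dataset: np.ndarray,
    labels: np.ndarray,
    features_sets: Sequence[List[int]],
) -> Optional[List[int]]:
    # Keep the feature subset whose fitted SVM scores highest on the data.
    max_gain: float = 0.0
    selected: Optional[List[int]] = None
    clf = LinearSVC()
    for feature_set in features_sets:
        clf.fit(dataset[:, feature_set], labels)
        gain = clf.score(dataset[:, feature_set], labels)
        if gain > max_gain:
            max_gain = gain
            selected = list(feature_set)
    return selected
```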
@@ -265,7 +265,7 @@ class Splitter:
def get_subspace(
self, dataset: np.array, labels: np.array, max_features: int
) -> list:
) -> Tuple[np.array, np.array]:
"""Return the best subspace to make a split
"""
indices = self._get_subspaces_set(dataset, labels, max_features)
@@ -478,7 +478,7 @@ class Stree(BaseEstimator, ClassifierMixin):
sample_weight: np.ndarray,
depth: int,
title: str,
) -> Snode:
) -> Optional[Snode]:
"""Recursive function to split the original dataset into predictor
nodes (leaves)
@@ -543,11 +543,13 @@ class Stree(BaseEstimator, ClassifierMixin):
node.set_down(self.train(X_D, y_d, sw_d, depth + 1, title + " - Down"))
return node
def _build_predictor(self):
def _build_predictor(self) -> None:
"""Process the leaves to make them predictors
"""
def run_tree(node: Snode):
def run_tree(node: Optional[Snode]) -> None:
if node is None:
raise ValueError("Can't build predictors on None")
if node.is_leaf():
node.make_predictor()
return
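The recursive descent into the children falls outside this hunk; a standalone sketch of the whole walk under the same None handling, with a hypothetical helper name and an assumed visit order:

```python
from typing import Optional

def build_predictors(node: Optional["Snode"]) -> None:
    # Refuse to run on an empty tree, mirroring the ValueError added above.
    if node is None:
        raise ValueError("Can't build predictors on None")
    if node.is_leaf():
        node.make_predictor()
        return
    # Non-leaf nodes: recurse into both children.
    build_predictors(node.get_down())
    build_predictors(node.get_up())
```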
@@ -556,7 +558,7 @@ class Stree(BaseEstimator, ClassifierMixin):
run_tree(self.tree_)
def _build_clf(self):
def _build_clf(self) -> Union[LinearSVC, SVC]:
""" Build the correct classifier for the node
"""
return (
@@ -605,7 +607,7 @@ class Stree(BaseEstimator, ClassifierMixin):
"""
def predict_class(
xp: np.array, indices: np.array, node: Snode
xp: np.array, indices: np.array, node: Optional[Snode]
) -> np.array:
if xp is None:
return [], []
@@ -704,7 +706,7 @@ class Stree(BaseEstimator, ClassifierMixin):
)
elif self.max_features is None:
max_features = self.n_features_
elif isinstance(self.max_features, numbers.Integral):
elif isinstance(self.max_features, int):
max_features = self.max_features
else: # float
if self.max_features > 0.0:
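Taken together, this branch resolves max_features into a concrete feature count: None means all features, an int is an absolute count, and a float is a fraction. A standalone sketch of that resolution under the new isinstance(..., int) check; the helper name and the float-to-count rounding are assumptions:

```python
from typing import Union

def resolve_max_features(
    max_features: Union[int, float, None], n_features: int
) -> int:
    # None means "use every feature".
    if max_features is None:
        return n_features
    # A plain Python int is taken as an absolute count.
    if isinstance(max_features, int):
        return max_features
    # Otherwise treat the value as a fraction of the available features.
    if max_features > 0.0:
        return max(1, int(max_features * n_features))
    raise ValueError("max_features must be greater than 0")
```

Note that isinstance(x, int) is stricter than the numbers.Integral check it replaces: NumPy integer scalars such as np.int64 satisfy numbers.Integral but not int, so such values would now fall through to the float branch.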


@@ -414,3 +414,24 @@ class Stree_test(unittest.TestCase):
# zero weights are ok when they don't erase a class
_ = clf.train(X, y, weights_no_zero, 1, "test")
self.assertListEqual(weights_no_zero.tolist(), original.tolist())
def test_build_predictor(self):
X, y = load_dataset(self._random_state)
clf = Stree(random_state=self._random_state)
with self.assertRaises(ValueError):
clf.tree_ = None
clf._build_predictor()
clf.fit(X, y)
node = clf.tree_.get_down().get_down()
expected_impurity = 0.04686951386893923
expected_class = 1
expected_belief = 0.9759887005649718
self.assertAlmostEqual(expected_impurity, node._impurity)
self.assertAlmostEqual(expected_belief, node._belief)
self.assertEqual(expected_class, node._class)
node._belief = 0.0
node._class = None
clf._build_predictor()
node = clf.tree_.get_down().get_down()
self.assertAlmostEqual(expected_belief, node._belief)
self.assertEqual(expected_class, node._class)
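A minimal usage sketch of the estimator the new test exercises; the import path and the synthetic dataset are assumptions, not taken from the test suite:

```python
from sklearn.datasets import make_classification
from stree import Stree  # assumed public import path

X, y = make_classification(n_samples=500, random_state=0)
clf = Stree(random_state=0).fit(X, y)
# Leaves become predictors during fit, so predict()/score() work as usual.
print(clf.score(X, y))
```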