diff --git a/.coveragerc b/.coveragerc index 78a0f78..3846718 100644 --- a/.coveragerc +++ b/.coveragerc @@ -9,6 +9,4 @@ exclude_lines = raise NotImplementedError if __name__ == .__main__.: ignore_errors = True -omit = - odte/tests/* - odte/__init__.py \ No newline at end of file +omit = \ No newline at end of file diff --git a/notebooks/benchmark.ipynb b/notebooks/benchmark.ipynb index 6d7bb87..9503e8d 100644 --- a/notebooks/benchmark.ipynb +++ b/notebooks/benchmark.ipynb @@ -219,7 +219,7 @@ "outputs": [], "source": [ "# Oblique Decision Tree Ensemble\n", - "odte = Odte(random_state=random_state, n_estimators=10, max_features=None)" + "odte = Odte(random_state=random_state, n_estimators=10, max_features=\"auto\")" ] }, { diff --git a/odte/Odte.py b/odte/Odte.py index d9546a1..989f6c5 100644 --- a/odte/Odte.py +++ b/odte/Odte.py @@ -6,13 +6,15 @@ __version__ = "0.1" Build a forest of oblique trees based on STree """ +import random +from typing import Union +from itertools import combinations import numpy as np - from sklearn.utils import check_consistent_length from sklearn.metrics._classification import _weighted_sum, _check_targets from sklearn.utils.multiclass import check_classification_targets -from sklearn.base import BaseEstimator, ClassifierMixin -from scipy.stats import mode +from sklearn.base import clone, ClassifierMixin +from sklearn.ensemble import BaseEnsemble from sklearn.utils.validation import ( check_X_y, check_array, @@ -23,44 +25,30 @@ from sklearn.utils.validation import ( from stree import Stree -class Odte(BaseEstimator, ClassifierMixin): +class Odte(BaseEnsemble, ClassifierMixin): def __init__( self, + base_estimator=None, random_state: int = None, - C: int = 1, + max_features: Union[str, int, float] = 1.0, + max_samples: Union[int, float] = None, n_estimators: int = 100, - max_iter: int = 1000, - max_depth: int = None, - min_samples_split: int = 0, - split_criteria: str = "min_distance", - criterion: str = "gini", - tol: float = 1e-4, - gamma="scale", - degree: int = 3, - kernel: str = "linear", - max_features="auto", - max_samples=None, - splitter: str = "random", ): + base_estimator = ( + Stree(random_state=random_state) + if base_estimator is None + else base_estimator + ) + super().__init__( + base_estimator=base_estimator, n_estimators=n_estimators, + ) self.n_estimators = n_estimators self.random_state = random_state self.max_features = max_features self.max_samples = max_samples # size of bootstrap - self.estimator_params = dict( - C=C, - random_state=random_state, - min_samples_split=min_samples_split, - max_depth=max_depth, - split_criteria=split_criteria, - criterion=criterion, - kernel=kernel, - max_iter=max_iter, - tol=tol, - degree=degree, - gamma=gamma, - splitter=splitter, - max_features=max_features, - ) + + def _more_tags(self) -> dict: + return {"requires_y": True} def _initialize_random(self) -> np.random.mtrand.RandomState: if self.random_state is None: @@ -77,6 +65,12 @@ class Odte(BaseEstimator, ClassifierMixin): else: return sample_weight.copy() + def _validate_estimator(self): + """Check the estimator and set the base_estimator_ attribute.""" + super()._validate_estimator( + default=Stree(random_state=self.random_state) + ) + def fit( self, X: np.array, y: np.array, sample_weight: np.array = None ) -> "Odte": @@ -89,9 +83,16 @@ class Odte(BaseEstimator, ClassifierMixin): # the rest of parameters are checked in estimator check_classification_targets(y) X, y = check_X_y(X, y) - sample_weight = _check_sample_weight(sample_weight, X) + sample_weight = _check_sample_weight( + sample_weight, X, dtype=np.float64 + ) check_classification_targets(y) # Initialize computed parameters + # Build the estimator + self.n_features_in_ = X.shape[1] + self.n_features = X.shape[1] + self.max_features_ = self._initialize_max_features() + self._validate_estimator() self.classes_, y = np.unique(y, return_inverse=True) self.n_classes_ = self.classes_.shape[0] self.estimators_ = [] @@ -107,15 +108,17 @@ class Odte(BaseEstimator, ClassifierMixin): boot_samples = self._get_bootstrap_n_samples(n_samples) for _ in range(self.n_estimators): # Build clf - clf = Stree().set_params(**self.estimator_params) + clf = clone(self.base_estimator_) + # clf.set_params(**self.estimator_params) self.estimators_.append(clf) # bootstrap indices = random_box.randint(0, n_samples, boot_samples) # update weights with the chosen samples weights_update = np.bincount(indices, minlength=n_samples) + features = self.get_subspace(X, y) current_weights = weights * weights_update # train the classifier - clf.fit(X[indices, :], y[indices], current_weights[indices]) + clf.fit(X[indices, features], y[indices], current_weights[indices]) def _get_bootstrap_n_samples(self, n_samples) -> int: if self.max_samples is None: @@ -137,15 +140,69 @@ class Odte(BaseEstimator, ClassifierMixin): {type(self.max_samples)}" ) + def _initialize_max_features(self) -> int: + if isinstance(self.max_features, str): + if self.max_features == "auto": + max_features = max(1, int(np.sqrt(self.n_features_))) + elif self.max_features == "sqrt": + max_features = max(1, int(np.sqrt(self.n_features_))) + elif self.max_features == "log2": + max_features = max(1, int(np.log2(self.n_features_))) + else: + raise ValueError( + "Invalid value for max_features. " + "Allowed string values are 'auto', " + "'sqrt' or 'log2'." + ) + elif self.max_features is None: + max_features = self.n_features_ + elif isinstance(self.max_features, int): + max_features = self.max_features + else: # float + if self.max_features > 0.0: + max_features = max( + 1, int(self.max_features * self.n_features_) + ) + else: + raise ValueError( + "Invalid value for max_features." + "Allowed float must be in range (0, 1] " + f"got ({self.max_features})" + ) + return max_features + + def _get_subspaces_set( + self, dataset: np.array, labels: np.array + ) -> np.array: + features = range(dataset.shape[1]) + features_sets = list(combinations(features, self.max_features_)) + if len(features_sets) > 1: + index = random.randint(0, len(features_sets) - 1) + return features_sets[index] + else: + return features_sets[0] + + def get_subspace(self, dataset: np.array, labels: np.array) -> list: + """Return the best subspace to build a tree + """ + indices = self._get_subspaces_set(dataset, labels) + return dataset[:, indices], indices + def predict(self, X: np.array) -> np.array: - # todo + proba = self.predict_proba(X) + return self.classes_.take((np.argmax(proba, axis=1)), axis=0) + + def predict_proba(self, X: np.array) -> np.array: check_is_fitted(self, ["estimators_"]) # Input validation X = check_array(X) - result = np.empty((X.shape[0], self.n_estimators)) - for index, tree in enumerate(self.estimators_): - result[:, index] = tree.predict(X) - return mode(result, axis=1).mode.ravel() + for tree in self.estimators_: + n_samples = X.shape[0] + result = np.zeros((n_samples, self.n_classes_)) + predictions = tree.predict(X) + for i in range(n_samples): + result[i, predictions[i]] += 1 + return result def score( self, X: np.array, y: np.array, sample_weight: np.array = None diff --git a/odte/tests/Odte_tests.py b/odte/tests/Odte_tests.py index 6209ffa..6ee3f74 100644 --- a/odte/tests/Odte_tests.py +++ b/odte/tests/Odte_tests.py @@ -62,9 +62,13 @@ class Odte_test(unittest.TestCase): warnings.filterwarnings("ignore", category=ConvergenceWarning) warnings.filterwarnings("ignore", category=RuntimeWarning) X, y = [[1, 2], [5, 6], [9, 10], [16, 17]], [0, 1, 1, 2] - expected = [0, 1, 1, 0] - tclf = Odte( - random_state=self._random_state, n_estimators=10, kernel="rbf" + expected = [1, 1, 1, 1] + tclf = Odte(random_state=self._random_state, n_estimators=10,) + tclf.set_params( + **dict( + base_estimator__kernel="rbf", + base_estimator__random_state=self._random_state, + ) ) computed = tclf.fit(X, y).predict(X) self.assertListEqual(expected, computed.tolist()) @@ -77,32 +81,47 @@ class Odte_test(unittest.TestCase): tclf = Odte( random_state=self._random_state, max_features=None, - kernel="linear", max_samples=0.1, ) + tclf.set_params(**dict(base_estimator__kernel="linear",)) computed = tclf.fit(X, y).predict(X) self.assertListEqual(expected[:27].tolist(), computed[:27].tolist()) def test_score(self): X, y = load_dataset(self._random_state) - expected = 0.9526666666666667 + expected = 0.948 tclf = Odte( - random_state=self._random_state, max_features=None, n_estimators=10 + random_state=self._random_state, + max_features=None, + n_estimators=10, ) computed = tclf.fit(X, y).score(X, y) self.assertAlmostEqual(expected, computed) def test_score_splitter_max_features(self): X, y = load_dataset(self._random_state, n_features=12, n_samples=150) - results = [1.0, 0.94, 0.9933333333333333, 0.9933333333333333] + results = [ + 0.9866666666666667, + 0.9866666666666667, + 0.9866666666666667, + 0.9866666666666667, + ] for max_features in ["auto", None]: for splitter in ["best", "random"]: tclf = Odte( random_state=self._random_state, - splitter=splitter, max_features=max_features, n_estimators=10, ) + tclf.set_params(**dict(base_estimator__splitter=splitter,)) expected = results.pop(0) computed = tclf.fit(X, y).score(X, y) self.assertAlmostEqual(expected, computed) + + @staticmethod + def test_is_a_sklearn_classifier(): + warnings.filterwarnings("ignore", category=ConvergenceWarning) + warnings.filterwarnings("ignore", category=RuntimeWarning) + from sklearn.utils.estimator_checks import check_estimator + + check_estimator(Odte())