diff --git a/odte/Odte.py b/odte/Odte.py index e1526b9..bd202e5 100644 --- a/odte/Odte.py +++ b/odte/Odte.py @@ -2,11 +2,13 @@ __author__ = "Ricardo Montañana Gómez" __copyright__ = "Copyright 2020, Ricardo Montañana Gómez" __license__ = "MIT" -Build a forest of oblique trees based on STree +Build a forest of oblique trees based on STree, admits any base classifier +as well """ from __future__ import annotations import random import sys +import json from math import factorial from typing import Union, Optional, Tuple, List, Set import numpy as np @@ -33,6 +35,7 @@ class Odte(BaseEnsemble, ClassifierMixin): max_features: Optional[Union[str, int, float]] = None, max_samples: Optional[Union[int, float]] = None, n_estimators: int = 100, + be_hyperparams: str = "{}", ): super().__init__( base_estimator=base_estimator, @@ -44,6 +47,7 @@ class Odte(BaseEnsemble, ClassifierMixin): self.random_state = random_state self.max_features = max_features self.max_samples = max_samples # size of bootstrap + self.be_hyperparams = be_hyperparams def _initialize_random(self) -> np.random.mtrand.RandomState: if self.random_state is None: @@ -110,9 +114,12 @@ class Odte(BaseEnsemble, ClassifierMixin): random_seed: int, boot_samples: int, max_features: int, + hyperparams: str, ) -> Tuple[BaseEstimator, Tuple[int, ...]]: clf = clone(base_estimator_) - clf.set_params(random_state=random_seed) + hyperparams_ = json.loads(hyperparams) + hyperparams_.update(dict(random_state=random_seed)) + clf.set_params(**hyperparams_) n_samples = X.shape[0] # bootstrap indices = random_box.randint(0, n_samples, boot_samples) @@ -143,6 +150,7 @@ class Odte(BaseEnsemble, ClassifierMixin): random_seed, boot_samples, self.max_features_, + self.be_hyperparams, ) for random_seed in range( self.random_state, self.random_state + self.n_estimators diff --git a/odte/tests/Odte_tests.py b/odte/tests/Odte_tests.py index d52dd95..0fc16b4 100644 --- a/odte/tests/Odte_tests.py +++ b/odte/tests/Odte_tests.py @@ -3,8 +3,9 @@ import unittest import os import random import warnings +import json from sklearn.exceptions import ConvergenceWarning, NotFittedError - +from sklearn.svm import SVC from odte import Odte from stree import Stree from .utils import load_dataset @@ -215,3 +216,39 @@ class Odte_test(unittest.TestCase): nodes, leaves = tclf.nodes_leaves() self.assertAlmostEqual(9.333333333333334, leaves) self.assertAlmostEqual(17.666666666666668, nodes) + + def test_nodes_leaves_SVC(self): + tclf = Odte( + base_estimator=SVC(), + random_state=self._random_state, + n_estimators=3, + ) + X, y = load_dataset(self._random_state, n_features=16, n_samples=500) + tclf.fit(X, y) + self.assertAlmostEqual(0.0, tclf.leaves_) + self.assertAlmostEqual(0.0, tclf.nodes_) + nodes, leaves = tclf.nodes_leaves() + self.assertAlmostEqual(0.0, leaves) + self.assertAlmostEqual(0.0, nodes) + + def test_base_estimator_hyperparams(self): + data = [ + (Stree(), {"max_features": 7, "max_depth": 2}), + (SVC(), {"kernel": "linear", "cache_size": 100}), + ] + for clf, hyperparams in data: + hyperparams_ = json.dumps(hyperparams) + tclf = Odte( + base_estimator=clf, + random_state=self._random_state, + n_estimators=3, + be_hyperparams=hyperparams_, + ) + self.assertEqual(hyperparams_, tclf.be_hyperparams) + X, y = load_dataset( + self._random_state, n_features=16, n_samples=500 + ) + tclf.fit(X, y) + for estimator in tclf.estimators_: + for key, value in hyperparams.items(): + self.assertEqual(value, estimator.get_params()[key])