From 74343a15e17723cc89df3d757d10dec777061452 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Wed, 24 Nov 2021 10:50:19 +0100 Subject: [PATCH 1/3] Fix nodes_leaves for base_estimator --- odte/Odte.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/odte/Odte.py b/odte/Odte.py index 890f649..e1526b9 100644 --- a/odte/Odte.py +++ b/odte/Odte.py @@ -88,13 +88,14 @@ class Odte(BaseEnsemble, ClassifierMixin): return self def _compute_metrics(self) -> None: - tdepth = tnodes = tleaves = 0 + tdepth = tnodes = tleaves = 0.0 for estimator in self.estimators_: - nodes, leaves = estimator.nodes_leaves() - depth = estimator.depth_ - tdepth += depth - tnodes += nodes - tleaves += leaves + if hasattr(estimator, "nodes_leaves"): + nodes, leaves = estimator.nodes_leaves() + depth = estimator.depth_ + tdepth += depth + tnodes += nodes + tleaves += leaves self.depth_ = tdepth / self.n_estimators self.leaves_ = tleaves / self.n_estimators self.nodes_ = tnodes / self.n_estimators From 525ee93fc35c197967e77ed20b83c2fde55e6df0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Wed, 24 Nov 2021 12:34:36 +0100 Subject: [PATCH 2/3] (#3)Add base estimator hyperparameters --- odte/Odte.py | 12 ++++++++++-- odte/tests/Odte_tests.py | 39 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/odte/Odte.py b/odte/Odte.py index e1526b9..bd202e5 100644 --- a/odte/Odte.py +++ b/odte/Odte.py @@ -2,11 +2,13 @@ __author__ = "Ricardo Montañana Gómez" __copyright__ = "Copyright 2020, Ricardo Montañana Gómez" __license__ = "MIT" -Build a forest of oblique trees based on STree +Build a forest of oblique trees based on STree, admits any base classifier +as well """ from __future__ import annotations import random import sys +import json from math import factorial from typing import Union, Optional, Tuple, List, Set import numpy as np @@ -33,6 +35,7 @@ class Odte(BaseEnsemble, ClassifierMixin): max_features: Optional[Union[str, int, float]] = None, max_samples: Optional[Union[int, float]] = None, n_estimators: int = 100, + be_hyperparams: str = "{}", ): super().__init__( base_estimator=base_estimator, @@ -44,6 +47,7 @@ class Odte(BaseEnsemble, ClassifierMixin): self.random_state = random_state self.max_features = max_features self.max_samples = max_samples # size of bootstrap + self.be_hyperparams = be_hyperparams def _initialize_random(self) -> np.random.mtrand.RandomState: if self.random_state is None: @@ -110,9 +114,12 @@ class Odte(BaseEnsemble, ClassifierMixin): random_seed: int, boot_samples: int, max_features: int, + hyperparams: str, ) -> Tuple[BaseEstimator, Tuple[int, ...]]: clf = clone(base_estimator_) - clf.set_params(random_state=random_seed) + hyperparams_ = json.loads(hyperparams) + hyperparams_.update(dict(random_state=random_seed)) + clf.set_params(**hyperparams_) n_samples = X.shape[0] # bootstrap indices = random_box.randint(0, n_samples, boot_samples) @@ -143,6 +150,7 @@ class Odte(BaseEnsemble, ClassifierMixin): random_seed, boot_samples, self.max_features_, + self.be_hyperparams, ) for random_seed in range( self.random_state, self.random_state + self.n_estimators diff --git a/odte/tests/Odte_tests.py b/odte/tests/Odte_tests.py index d52dd95..0fc16b4 100644 --- a/odte/tests/Odte_tests.py +++ b/odte/tests/Odte_tests.py @@ -3,8 +3,9 @@ import unittest import os import random import warnings +import json from sklearn.exceptions import ConvergenceWarning, NotFittedError - +from sklearn.svm import SVC from odte import Odte from stree import Stree from .utils import load_dataset @@ -215,3 +216,39 @@ class Odte_test(unittest.TestCase): nodes, leaves = tclf.nodes_leaves() self.assertAlmostEqual(9.333333333333334, leaves) self.assertAlmostEqual(17.666666666666668, nodes) + + def test_nodes_leaves_SVC(self): + tclf = Odte( + base_estimator=SVC(), + random_state=self._random_state, + n_estimators=3, + ) + X, y = load_dataset(self._random_state, n_features=16, n_samples=500) + tclf.fit(X, y) + self.assertAlmostEqual(0.0, tclf.leaves_) + self.assertAlmostEqual(0.0, tclf.nodes_) + nodes, leaves = tclf.nodes_leaves() + self.assertAlmostEqual(0.0, leaves) + self.assertAlmostEqual(0.0, nodes) + + def test_base_estimator_hyperparams(self): + data = [ + (Stree(), {"max_features": 7, "max_depth": 2}), + (SVC(), {"kernel": "linear", "cache_size": 100}), + ] + for clf, hyperparams in data: + hyperparams_ = json.dumps(hyperparams) + tclf = Odte( + base_estimator=clf, + random_state=self._random_state, + n_estimators=3, + be_hyperparams=hyperparams_, + ) + self.assertEqual(hyperparams_, tclf.be_hyperparams) + X, y = load_dataset( + self._random_state, n_features=16, n_samples=500 + ) + tclf.fit(X, y) + for estimator in tclf.estimators_: + for key, value in hyperparams.items(): + self.assertEqual(value, estimator.get_params()[key]) From 67424e06be109b258882e413e107e388b7dc1586 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Wed, 24 Nov 2021 12:54:25 +0100 Subject: [PATCH 3/3] Add python versions 3.9 & 3.10 to github actions --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 6c1a964..1b4e627 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: os: [macos-latest, ubuntu-latest, windows-latest] - python: [3.8] + python: [3.8, 3.9, "3.10"] steps: - uses: actions/checkout@v2