(#3)Add base estimator hyperparameters

This commit is contained in:
Ricardo Montañana Gómez 2021-11-24 12:34:36 +01:00
parent 74343a15e1
commit 525ee93fc3
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
2 changed files with 48 additions and 3 deletions

View File

@ -2,11 +2,13 @@
__author__ = "Ricardo Montañana Gómez" __author__ = "Ricardo Montañana Gómez"
__copyright__ = "Copyright 2020, Ricardo Montañana Gómez" __copyright__ = "Copyright 2020, Ricardo Montañana Gómez"
__license__ = "MIT" __license__ = "MIT"
Build a forest of oblique trees based on STree Build a forest of oblique trees based on STree, admits any base classifier
as well
""" """
from __future__ import annotations from __future__ import annotations
import random import random
import sys import sys
import json
from math import factorial from math import factorial
from typing import Union, Optional, Tuple, List, Set from typing import Union, Optional, Tuple, List, Set
import numpy as np import numpy as np
@ -33,6 +35,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
max_features: Optional[Union[str, int, float]] = None, max_features: Optional[Union[str, int, float]] = None,
max_samples: Optional[Union[int, float]] = None, max_samples: Optional[Union[int, float]] = None,
n_estimators: int = 100, n_estimators: int = 100,
be_hyperparams: str = "{}",
): ):
super().__init__( super().__init__(
base_estimator=base_estimator, base_estimator=base_estimator,
@ -44,6 +47,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
self.random_state = random_state self.random_state = random_state
self.max_features = max_features self.max_features = max_features
self.max_samples = max_samples # size of bootstrap self.max_samples = max_samples # size of bootstrap
self.be_hyperparams = be_hyperparams
def _initialize_random(self) -> np.random.mtrand.RandomState: def _initialize_random(self) -> np.random.mtrand.RandomState:
if self.random_state is None: if self.random_state is None:
@ -110,9 +114,12 @@ class Odte(BaseEnsemble, ClassifierMixin):
random_seed: int, random_seed: int,
boot_samples: int, boot_samples: int,
max_features: int, max_features: int,
hyperparams: str,
) -> Tuple[BaseEstimator, Tuple[int, ...]]: ) -> Tuple[BaseEstimator, Tuple[int, ...]]:
clf = clone(base_estimator_) clf = clone(base_estimator_)
clf.set_params(random_state=random_seed) hyperparams_ = json.loads(hyperparams)
hyperparams_.update(dict(random_state=random_seed))
clf.set_params(**hyperparams_)
n_samples = X.shape[0] n_samples = X.shape[0]
# bootstrap # bootstrap
indices = random_box.randint(0, n_samples, boot_samples) indices = random_box.randint(0, n_samples, boot_samples)
@ -143,6 +150,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
random_seed, random_seed,
boot_samples, boot_samples,
self.max_features_, self.max_features_,
self.be_hyperparams,
) )
for random_seed in range( for random_seed in range(
self.random_state, self.random_state + self.n_estimators self.random_state, self.random_state + self.n_estimators

View File

@ -3,8 +3,9 @@ import unittest
import os import os
import random import random
import warnings import warnings
import json
from sklearn.exceptions import ConvergenceWarning, NotFittedError from sklearn.exceptions import ConvergenceWarning, NotFittedError
from sklearn.svm import SVC
from odte import Odte from odte import Odte
from stree import Stree from stree import Stree
from .utils import load_dataset from .utils import load_dataset
@ -215,3 +216,39 @@ class Odte_test(unittest.TestCase):
nodes, leaves = tclf.nodes_leaves() nodes, leaves = tclf.nodes_leaves()
self.assertAlmostEqual(9.333333333333334, leaves) self.assertAlmostEqual(9.333333333333334, leaves)
self.assertAlmostEqual(17.666666666666668, nodes) self.assertAlmostEqual(17.666666666666668, nodes)
def test_nodes_leaves_SVC(self):
tclf = Odte(
base_estimator=SVC(),
random_state=self._random_state,
n_estimators=3,
)
X, y = load_dataset(self._random_state, n_features=16, n_samples=500)
tclf.fit(X, y)
self.assertAlmostEqual(0.0, tclf.leaves_)
self.assertAlmostEqual(0.0, tclf.nodes_)
nodes, leaves = tclf.nodes_leaves()
self.assertAlmostEqual(0.0, leaves)
self.assertAlmostEqual(0.0, nodes)
def test_base_estimator_hyperparams(self):
data = [
(Stree(), {"max_features": 7, "max_depth": 2}),
(SVC(), {"kernel": "linear", "cache_size": 100}),
]
for clf, hyperparams in data:
hyperparams_ = json.dumps(hyperparams)
tclf = Odte(
base_estimator=clf,
random_state=self._random_state,
n_estimators=3,
be_hyperparams=hyperparams_,
)
self.assertEqual(hyperparams_, tclf.be_hyperparams)
X, y = load_dataset(
self._random_state, n_features=16, n_samples=500
)
tclf.fit(X, y)
for estimator in tclf.estimators_:
for key, value in hyperparams.items():
self.assertEqual(value, estimator.get_params()[key])