Add base estimator hyperparameters (#3)

Ricardo Montañana Gómez 2021-11-24 12:34:36 +01:00
parent 74343a15e1
commit 525ee93fc3
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
2 changed files with 48 additions and 3 deletions
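
For context, the new be_hyperparams argument is a JSON-encoded dict of hyperparameters that Odte forwards to every cloned base estimator. A minimal usage sketch, assuming the Stree base estimator and hyperparameter values taken from the tests below, and a synthetic dataset from scikit-learn (make_classification is illustrative only, not part of this commit):

import json
from sklearn.datasets import make_classification
from stree import Stree
from odte import Odte

# Hyperparameters for the base estimator, serialized as a JSON string
hyperparams = json.dumps({"max_features": 7, "max_depth": 2})

X, y = make_classification(n_samples=500, n_features=16, random_state=0)

clf = Odte(
    base_estimator=Stree(),
    n_estimators=3,
    random_state=0,
    be_hyperparams=hyperparams,
)
clf.fit(X, y)  # every cloned Stree gets max_features=7 and max_depth=2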

View File

@@ -2,11 +2,13 @@
__author__ = "Ricardo Montañana Gómez"
__copyright__ = "Copyright 2020, Ricardo Montañana Gómez"
__license__ = "MIT"
Build a forest of oblique trees based on STree
Build a forest of oblique trees based on STree; any other base classifier
is admitted as well
"""
from __future__ import annotations
import random
import sys
import json
from math import factorial
from typing import Union, Optional, Tuple, List, Set
import numpy as np
@@ -33,6 +35,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
        max_features: Optional[Union[str, int, float]] = None,
        max_samples: Optional[Union[int, float]] = None,
        n_estimators: int = 100,
        be_hyperparams: str = "{}",
    ):
        super().__init__(
            base_estimator=base_estimator,
@@ -44,6 +47,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
        self.random_state = random_state
        self.max_features = max_features
        self.max_samples = max_samples  # size of bootstrap
        self.be_hyperparams = be_hyperparams

    def _initialize_random(self) -> np.random.mtrand.RandomState:
        if self.random_state is None:
@@ -110,9 +114,12 @@ class Odte(BaseEnsemble, ClassifierMixin):
        random_seed: int,
        boot_samples: int,
        max_features: int,
        hyperparams: str,
    ) -> Tuple[BaseEstimator, Tuple[int, ...]]:
        clf = clone(base_estimator_)
        clf.set_params(random_state=random_seed)
        hyperparams_ = json.loads(hyperparams)
        hyperparams_.update(dict(random_state=random_seed))
        clf.set_params(**hyperparams_)
        n_samples = X.shape[0]
        # bootstrap
        indices = random_box.randint(0, n_samples, boot_samples)
@@ -143,6 +150,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
                random_seed,
                boot_samples,
                self.max_features_,
                self.be_hyperparams,
            )
            for random_seed in range(
                self.random_state, self.random_state + self.n_estimators
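
A consequence of the change in _parallel_build_tree worth noting: the per-tree seed is merged into the decoded hyperparameters after json.loads, so a random_state passed through be_hyperparams is always overridden by the seed assigned to each tree. A standalone sketch of that merge (plain Python, values illustrative only):

import json

be_hyperparams = '{"max_depth": 2, "random_state": 17}'
random_seed = 3  # seed assigned to one particular tree

hyperparams_ = json.loads(be_hyperparams)
hyperparams_.update(dict(random_state=random_seed))

# random_state ends up as the per-tree seed, not the value from the JSON string
print(hyperparams_)  # {'max_depth': 2, 'random_state': 3}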

View File

@@ -3,8 +3,9 @@ import unittest
import os
import random
import warnings
import json
from sklearn.exceptions import ConvergenceWarning, NotFittedError
from sklearn.svm import SVC
from odte import Odte
from stree import Stree
from .utils import load_dataset
@@ -215,3 +216,39 @@ class Odte_test(unittest.TestCase):
        nodes, leaves = tclf.nodes_leaves()
        self.assertAlmostEqual(9.333333333333334, leaves)
        self.assertAlmostEqual(17.666666666666668, nodes)

    def test_nodes_leaves_SVC(self):
        tclf = Odte(
            base_estimator=SVC(),
            random_state=self._random_state,
            n_estimators=3,
        )
        X, y = load_dataset(self._random_state, n_features=16, n_samples=500)
        tclf.fit(X, y)
        self.assertAlmostEqual(0.0, tclf.leaves_)
        self.assertAlmostEqual(0.0, tclf.nodes_)
        nodes, leaves = tclf.nodes_leaves()
        self.assertAlmostEqual(0.0, leaves)
        self.assertAlmostEqual(0.0, nodes)

    def test_base_estimator_hyperparams(self):
        data = [
            (Stree(), {"max_features": 7, "max_depth": 2}),
            (SVC(), {"kernel": "linear", "cache_size": 100}),
        ]
        for clf, hyperparams in data:
            hyperparams_ = json.dumps(hyperparams)
            tclf = Odte(
                base_estimator=clf,
                random_state=self._random_state,
                n_estimators=3,
                be_hyperparams=hyperparams_,
            )
            self.assertEqual(hyperparams_, tclf.be_hyperparams)
            X, y = load_dataset(
                self._random_state, n_features=16, n_samples=500
            )
            tclf.fit(X, y)
            for estimator in tclf.estimators_:
                for key, value in hyperparams.items():
                    self.assertEqual(value, estimator.get_params()[key])