mirror of
https://github.com/Doctorado-ML/Odte.git
synced 2025-07-11 00:02:30 +00:00
(#3)Add base estimator hyperparameters
This commit is contained in:
parent
74343a15e1
commit
525ee93fc3
12
odte/Odte.py
12
odte/Odte.py
@ -2,11 +2,13 @@
|
||||
__author__ = "Ricardo Montañana Gómez"
|
||||
__copyright__ = "Copyright 2020, Ricardo Montañana Gómez"
|
||||
__license__ = "MIT"
|
||||
Build a forest of oblique trees based on STree
|
||||
Build a forest of oblique trees based on STree, admits any base classifier
|
||||
as well
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import random
|
||||
import sys
|
||||
import json
|
||||
from math import factorial
|
||||
from typing import Union, Optional, Tuple, List, Set
|
||||
import numpy as np
|
||||
@ -33,6 +35,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
|
||||
max_features: Optional[Union[str, int, float]] = None,
|
||||
max_samples: Optional[Union[int, float]] = None,
|
||||
n_estimators: int = 100,
|
||||
be_hyperparams: str = "{}",
|
||||
):
|
||||
super().__init__(
|
||||
base_estimator=base_estimator,
|
||||
@ -44,6 +47,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
|
||||
self.random_state = random_state
|
||||
self.max_features = max_features
|
||||
self.max_samples = max_samples # size of bootstrap
|
||||
self.be_hyperparams = be_hyperparams
|
||||
|
||||
def _initialize_random(self) -> np.random.mtrand.RandomState:
|
||||
if self.random_state is None:
|
||||
@ -110,9 +114,12 @@ class Odte(BaseEnsemble, ClassifierMixin):
|
||||
random_seed: int,
|
||||
boot_samples: int,
|
||||
max_features: int,
|
||||
hyperparams: str,
|
||||
) -> Tuple[BaseEstimator, Tuple[int, ...]]:
|
||||
clf = clone(base_estimator_)
|
||||
clf.set_params(random_state=random_seed)
|
||||
hyperparams_ = json.loads(hyperparams)
|
||||
hyperparams_.update(dict(random_state=random_seed))
|
||||
clf.set_params(**hyperparams_)
|
||||
n_samples = X.shape[0]
|
||||
# bootstrap
|
||||
indices = random_box.randint(0, n_samples, boot_samples)
|
||||
@ -143,6 +150,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
|
||||
random_seed,
|
||||
boot_samples,
|
||||
self.max_features_,
|
||||
self.be_hyperparams,
|
||||
)
|
||||
for random_seed in range(
|
||||
self.random_state, self.random_state + self.n_estimators
|
||||
|
@ -3,8 +3,9 @@ import unittest
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
import json
|
||||
from sklearn.exceptions import ConvergenceWarning, NotFittedError
|
||||
|
||||
from sklearn.svm import SVC
|
||||
from odte import Odte
|
||||
from stree import Stree
|
||||
from .utils import load_dataset
|
||||
@ -215,3 +216,39 @@ class Odte_test(unittest.TestCase):
|
||||
nodes, leaves = tclf.nodes_leaves()
|
||||
self.assertAlmostEqual(9.333333333333334, leaves)
|
||||
self.assertAlmostEqual(17.666666666666668, nodes)
|
||||
|
||||
def test_nodes_leaves_SVC(self):
|
||||
tclf = Odte(
|
||||
base_estimator=SVC(),
|
||||
random_state=self._random_state,
|
||||
n_estimators=3,
|
||||
)
|
||||
X, y = load_dataset(self._random_state, n_features=16, n_samples=500)
|
||||
tclf.fit(X, y)
|
||||
self.assertAlmostEqual(0.0, tclf.leaves_)
|
||||
self.assertAlmostEqual(0.0, tclf.nodes_)
|
||||
nodes, leaves = tclf.nodes_leaves()
|
||||
self.assertAlmostEqual(0.0, leaves)
|
||||
self.assertAlmostEqual(0.0, nodes)
|
||||
|
||||
def test_base_estimator_hyperparams(self):
|
||||
data = [
|
||||
(Stree(), {"max_features": 7, "max_depth": 2}),
|
||||
(SVC(), {"kernel": "linear", "cache_size": 100}),
|
||||
]
|
||||
for clf, hyperparams in data:
|
||||
hyperparams_ = json.dumps(hyperparams)
|
||||
tclf = Odte(
|
||||
base_estimator=clf,
|
||||
random_state=self._random_state,
|
||||
n_estimators=3,
|
||||
be_hyperparams=hyperparams_,
|
||||
)
|
||||
self.assertEqual(hyperparams_, tclf.be_hyperparams)
|
||||
X, y = load_dataset(
|
||||
self._random_state, n_features=16, n_samples=500
|
||||
)
|
||||
tclf.fit(X, y)
|
||||
for estimator in tclf.estimators_:
|
||||
for key, value in hyperparams.items():
|
||||
self.assertEqual(value, estimator.get_params()[key])
|
||||
|
Loading…
x
Reference in New Issue
Block a user