mirror of
https://github.com/Doctorado-ML/Odte.git
synced 2025-07-11 08:12:06 +00:00
Merge pull request #4 from Doctorado-ML/be_hyperparams
Add base estimator hyperparameters
This commit is contained in:
commit
2ebec2d588
2
.github/workflows/main.yml
vendored
2
.github/workflows/main.yml
vendored
@ -13,7 +13,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [macos-latest, ubuntu-latest, windows-latest]
|
os: [macos-latest, ubuntu-latest, windows-latest]
|
||||||
python: [3.8]
|
python: [3.8, 3.9, "3.10"]
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
|
15
odte/Odte.py
15
odte/Odte.py
@ -2,11 +2,13 @@
|
|||||||
__author__ = "Ricardo Montañana Gómez"
|
__author__ = "Ricardo Montañana Gómez"
|
||||||
__copyright__ = "Copyright 2020, Ricardo Montañana Gómez"
|
__copyright__ = "Copyright 2020, Ricardo Montañana Gómez"
|
||||||
__license__ = "MIT"
|
__license__ = "MIT"
|
||||||
Build a forest of oblique trees based on STree
|
Build a forest of oblique trees based on STree, admits any base classifier
|
||||||
|
as well
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
|
import json
|
||||||
from math import factorial
|
from math import factorial
|
||||||
from typing import Union, Optional, Tuple, List, Set
|
from typing import Union, Optional, Tuple, List, Set
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -33,6 +35,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
|
|||||||
max_features: Optional[Union[str, int, float]] = None,
|
max_features: Optional[Union[str, int, float]] = None,
|
||||||
max_samples: Optional[Union[int, float]] = None,
|
max_samples: Optional[Union[int, float]] = None,
|
||||||
n_estimators: int = 100,
|
n_estimators: int = 100,
|
||||||
|
be_hyperparams: str = "{}",
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
base_estimator=base_estimator,
|
base_estimator=base_estimator,
|
||||||
@ -44,6 +47,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
|
|||||||
self.random_state = random_state
|
self.random_state = random_state
|
||||||
self.max_features = max_features
|
self.max_features = max_features
|
||||||
self.max_samples = max_samples # size of bootstrap
|
self.max_samples = max_samples # size of bootstrap
|
||||||
|
self.be_hyperparams = be_hyperparams
|
||||||
|
|
||||||
def _initialize_random(self) -> np.random.mtrand.RandomState:
|
def _initialize_random(self) -> np.random.mtrand.RandomState:
|
||||||
if self.random_state is None:
|
if self.random_state is None:
|
||||||
@ -88,8 +92,9 @@ class Odte(BaseEnsemble, ClassifierMixin):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def _compute_metrics(self) -> None:
|
def _compute_metrics(self) -> None:
|
||||||
tdepth = tnodes = tleaves = 0
|
tdepth = tnodes = tleaves = 0.0
|
||||||
for estimator in self.estimators_:
|
for estimator in self.estimators_:
|
||||||
|
if hasattr(estimator, "nodes_leaves"):
|
||||||
nodes, leaves = estimator.nodes_leaves()
|
nodes, leaves = estimator.nodes_leaves()
|
||||||
depth = estimator.depth_
|
depth = estimator.depth_
|
||||||
tdepth += depth
|
tdepth += depth
|
||||||
@ -109,9 +114,12 @@ class Odte(BaseEnsemble, ClassifierMixin):
|
|||||||
random_seed: int,
|
random_seed: int,
|
||||||
boot_samples: int,
|
boot_samples: int,
|
||||||
max_features: int,
|
max_features: int,
|
||||||
|
hyperparams: str,
|
||||||
) -> Tuple[BaseEstimator, Tuple[int, ...]]:
|
) -> Tuple[BaseEstimator, Tuple[int, ...]]:
|
||||||
clf = clone(base_estimator_)
|
clf = clone(base_estimator_)
|
||||||
clf.set_params(random_state=random_seed)
|
hyperparams_ = json.loads(hyperparams)
|
||||||
|
hyperparams_.update(dict(random_state=random_seed))
|
||||||
|
clf.set_params(**hyperparams_)
|
||||||
n_samples = X.shape[0]
|
n_samples = X.shape[0]
|
||||||
# bootstrap
|
# bootstrap
|
||||||
indices = random_box.randint(0, n_samples, boot_samples)
|
indices = random_box.randint(0, n_samples, boot_samples)
|
||||||
@ -142,6 +150,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
|
|||||||
random_seed,
|
random_seed,
|
||||||
boot_samples,
|
boot_samples,
|
||||||
self.max_features_,
|
self.max_features_,
|
||||||
|
self.be_hyperparams,
|
||||||
)
|
)
|
||||||
for random_seed in range(
|
for random_seed in range(
|
||||||
self.random_state, self.random_state + self.n_estimators
|
self.random_state, self.random_state + self.n_estimators
|
||||||
|
@ -3,8 +3,9 @@ import unittest
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import warnings
|
import warnings
|
||||||
|
import json
|
||||||
from sklearn.exceptions import ConvergenceWarning, NotFittedError
|
from sklearn.exceptions import ConvergenceWarning, NotFittedError
|
||||||
|
from sklearn.svm import SVC
|
||||||
from odte import Odte
|
from odte import Odte
|
||||||
from stree import Stree
|
from stree import Stree
|
||||||
from .utils import load_dataset
|
from .utils import load_dataset
|
||||||
@ -215,3 +216,39 @@ class Odte_test(unittest.TestCase):
|
|||||||
nodes, leaves = tclf.nodes_leaves()
|
nodes, leaves = tclf.nodes_leaves()
|
||||||
self.assertAlmostEqual(9.333333333333334, leaves)
|
self.assertAlmostEqual(9.333333333333334, leaves)
|
||||||
self.assertAlmostEqual(17.666666666666668, nodes)
|
self.assertAlmostEqual(17.666666666666668, nodes)
|
||||||
|
|
||||||
|
def test_nodes_leaves_SVC(self):
|
||||||
|
tclf = Odte(
|
||||||
|
base_estimator=SVC(),
|
||||||
|
random_state=self._random_state,
|
||||||
|
n_estimators=3,
|
||||||
|
)
|
||||||
|
X, y = load_dataset(self._random_state, n_features=16, n_samples=500)
|
||||||
|
tclf.fit(X, y)
|
||||||
|
self.assertAlmostEqual(0.0, tclf.leaves_)
|
||||||
|
self.assertAlmostEqual(0.0, tclf.nodes_)
|
||||||
|
nodes, leaves = tclf.nodes_leaves()
|
||||||
|
self.assertAlmostEqual(0.0, leaves)
|
||||||
|
self.assertAlmostEqual(0.0, nodes)
|
||||||
|
|
||||||
|
def test_base_estimator_hyperparams(self):
|
||||||
|
data = [
|
||||||
|
(Stree(), {"max_features": 7, "max_depth": 2}),
|
||||||
|
(SVC(), {"kernel": "linear", "cache_size": 100}),
|
||||||
|
]
|
||||||
|
for clf, hyperparams in data:
|
||||||
|
hyperparams_ = json.dumps(hyperparams)
|
||||||
|
tclf = Odte(
|
||||||
|
base_estimator=clf,
|
||||||
|
random_state=self._random_state,
|
||||||
|
n_estimators=3,
|
||||||
|
be_hyperparams=hyperparams_,
|
||||||
|
)
|
||||||
|
self.assertEqual(hyperparams_, tclf.be_hyperparams)
|
||||||
|
X, y = load_dataset(
|
||||||
|
self._random_state, n_features=16, n_samples=500
|
||||||
|
)
|
||||||
|
tclf.fit(X, y)
|
||||||
|
for estimator in tclf.estimators_:
|
||||||
|
for key, value in hyperparams.items():
|
||||||
|
self.assertEqual(value, estimator.get_params()[key])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user