mirror of
https://github.com/Doctorado-ML/Odte.git
synced 2025-07-11 08:12:06 +00:00
Working predict and score (basic)
This commit is contained in:
parent
ef6b5b08d5
commit
50e25bc372
@ -1,8 +1,6 @@
|
||||
[](https://app.codeship.com/projects/399830)
|
||||
|
||||
[](https://codecov.io/gh/Doctorado-ML/odte)
|
||||
|
||||
[](https://www.codacy.com/gh/Doctorado-ML/Odte?utm_source=github.com&utm_medium=referral&utm_content=Doctorado-ML/Odte&utm_campaign=Badge_Grade)
|
||||
|
||||
# odte
|
||||
# Odte
|
||||
Oblique Decision Tree Ensemble
|
||||
|
16
odte/Odte.py
16
odte/Odte.py
@ -12,6 +12,7 @@ from sklearn.utils import check_consistent_length
|
||||
from sklearn.metrics._classification import _weighted_sum, _check_targets
|
||||
from sklearn.utils.multiclass import check_classification_targets
|
||||
from sklearn.base import BaseEstimator, ClassifierMixin
|
||||
from scipy.stats import mode
|
||||
from sklearn.utils.validation import (
|
||||
check_X_y,
|
||||
check_array,
|
||||
@ -76,9 +77,9 @@ class Odte(BaseEstimator, ClassifierMixin):
|
||||
self, X: np.array, y: np.array, sample_weight: np.array = None
|
||||
) -> "Odte":
|
||||
# Check parameters are Ok.
|
||||
if self.n_estimators < 10:
|
||||
if self.n_estimators < 3:
|
||||
raise ValueError(
|
||||
f"n_estimators must be greater than 9... got (n_estimators=\
|
||||
f"n_estimators must be greater than 3... got (n_estimators=\
|
||||
{self.n_estimators:f})"
|
||||
)
|
||||
# the rest of parameters are checked in estimator
|
||||
@ -115,13 +116,13 @@ class Odte(BaseEstimator, ClassifierMixin):
|
||||
def _get_bootstrap_n_samples(self, n_samples) -> int:
|
||||
if self.max_samples is None:
|
||||
return n_samples
|
||||
if type(self.max_samples) == int:
|
||||
if isinstance(self.max_samples, int):
|
||||
if not (1 <= self.max_samples <= n_samples):
|
||||
message = f"max_samples should be in the range 1 to \
|
||||
{n_samples} but got {self.max_samples}"
|
||||
raise ValueError(message)
|
||||
return self.max_samples
|
||||
if type(self.max_samples) == float:
|
||||
if isinstance(self.max_samples, float):
|
||||
if not (0 < self.max_samples < 1):
|
||||
message = f"max_samples should be in the range (0, 1)\
|
||||
but got {self.max_samples}"
|
||||
@ -137,7 +138,10 @@ class Odte(BaseEstimator, ClassifierMixin):
|
||||
check_is_fitted(self, ["estimators_"])
|
||||
# Input validation
|
||||
X = check_array(X)
|
||||
return np.ones((X.shape[0]),)
|
||||
result = np.empty((X.shape[0], self.n_estimators))
|
||||
for index, tree in enumerate(self.estimators_):
|
||||
result[:, index] = tree.predict(X)
|
||||
return mode(result, axis=1).mode.ravel()
|
||||
|
||||
def score(
|
||||
self, X: np.array, y: np.array, sample_weight: np.array = None
|
||||
@ -148,7 +152,7 @@ class Odte(BaseEstimator, ClassifierMixin):
|
||||
X, y = check_X_y(X, y)
|
||||
y_pred = self.predict(X).reshape(y.shape)
|
||||
# Compute accuracy for each possible representation
|
||||
y_type, y_true, y_pred = _check_targets(y, y_pred)
|
||||
_, y_true, y_pred = _check_targets(y, y_pred)
|
||||
check_consistent_length(y_true, y_pred, sample_weight)
|
||||
score = y_true == y_pred
|
||||
return _weighted_sum(score, sample_weight, normalize=True)
|
||||
|
@ -52,24 +52,37 @@ class Odte_test(unittest.TestCase):
|
||||
self.assertLessEqual(value, 1000)
|
||||
|
||||
def test_bogus_n_estimator(self):
|
||||
values = [0, -1]
|
||||
values = [0, -1, 2]
|
||||
for n_estimators in values:
|
||||
with self.assertRaises(ValueError):
|
||||
tclf = Odte(n_estimators=n_estimators)
|
||||
tclf.fit(*load_dataset(self._random_state))
|
||||
|
||||
def test_simple_predict(self):
|
||||
warnings.filterwarnings("ignore", category=ConvergenceWarning)
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
X, y = [[1, 2], [5, 6], [9, 10], [16, 17]], [0, 1, 1, 2]
|
||||
expected = [0, 1, 1, 0]
|
||||
tclf = Odte(
|
||||
random_state=self._random_state, n_estimators=10, kernel="rbf"
|
||||
)
|
||||
computed = tclf.fit(X, y).predict(X)
|
||||
self.assertListEqual(expected, computed.tolist())
|
||||
|
||||
def test_predict(self):
|
||||
warnings.filterwarnings("ignore", category=ConvergenceWarning)
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
X, y = load_dataset(self._random_state)
|
||||
expected = np.ones(y.shape[0])
|
||||
tclf = Odte(random_state=self._random_state)
|
||||
expected = y
|
||||
tclf = Odte(
|
||||
random_state=self._random_state, n_estimators=10, kernel="linear"
|
||||
)
|
||||
computed = tclf.fit(X, y).predict(X)
|
||||
self.assertListEqual(expected.tolist(), computed.tolist())
|
||||
self.assertListEqual(expected[:27].tolist(), computed[:27].tolist())
|
||||
|
||||
def test_score(self):
|
||||
X, y = load_dataset(self._random_state)
|
||||
expected = 0.5
|
||||
tclf = Odte(random_state=self._random_state)
|
||||
expected = 0.9526666666666667
|
||||
tclf = Odte(random_state=self._random_state, n_estimators=10)
|
||||
computed = tclf.fit(X, y).score(X, y)
|
||||
self.assertAlmostEqual(expected, computed)
|
||||
|
@ -2,4 +2,4 @@ numpy
|
||||
scikit-learn
|
||||
pandas
|
||||
ipympl
|
||||
stree
|
||||
git+https://github.com/doctorado-ml/stree
|
Loading…
x
Reference in New Issue
Block a user