mirror of
https://github.com/Doctorado-ML/Odte.git
synced 2025-07-10 15:56:51 +00:00
Predict and score structure
This commit is contained in:
parent
8d4fdd4ee8
commit
ef6b5b08d5
4
.gitignore
vendored
4
.gitignore
vendored
@ -127,3 +127,7 @@ dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
.idea
|
||||
.vscode
|
||||
.pre-commit-config.yaml
|
@ -1,2 +1,8 @@
|
||||
[](https://app.codeship.com/projects/399830)
|
||||
|
||||
[](https://codecov.io/gh/Doctorado-ML/odte)
|
||||
|
||||
[](https://www.codacy.com/gh/Doctorado-ML/Odte?utm_source=github.com&utm_medium=referral&utm_content=Doctorado-ML/Odte&utm_campaign=Badge_Grade)
|
||||
|
||||
# odte
|
||||
Oblique Decision Tree Ensemble
|
||||
|
@ -91,6 +91,7 @@ class Odte(BaseEstimator, ClassifierMixin):
|
||||
self.n_classes_ = self.classes_.shape[0]
|
||||
self.estimators_ = []
|
||||
self._train(X, y, sample_weight)
|
||||
return self
|
||||
|
||||
def _train(
|
||||
self, X: np.array, y: np.array, sample_weight: np.array
|
||||
@ -109,7 +110,7 @@ class Odte(BaseEstimator, ClassifierMixin):
|
||||
weights_update = np.bincount(indices, minlength=n_samples)
|
||||
current_weights = weights * weights_update
|
||||
# train the classifier
|
||||
clf.fit(X[indices, :], y[indices, :], current_weights[indices, :])
|
||||
clf.fit(X[indices, :], y[indices], current_weights[indices])
|
||||
|
||||
def _get_bootstrap_n_samples(self, n_samples) -> int:
|
||||
if self.max_samples is None:
|
||||
@ -131,14 +132,15 @@ class Odte(BaseEstimator, ClassifierMixin):
|
||||
{type(self.max_samples)}"
|
||||
)
|
||||
|
||||
def predict(self, X: np.array):
|
||||
def predict(self, X: np.array) -> np.array:
|
||||
# todo
|
||||
check_is_fitted(self, ["estimators_"])
|
||||
# Input validation
|
||||
X = check_array(X)
|
||||
return np.ones((X.shape[0]),)
|
||||
|
||||
def score(
|
||||
self, X: np.array, y: np.array, sample_weight: np.array
|
||||
self, X: np.array, y: np.array, sample_weight: np.array = None
|
||||
) -> float:
|
||||
# todo
|
||||
check_is_fitted(self, ["estimators_"])
|
||||
|
@ -1,6 +1,9 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
|
||||
import warnings
|
||||
from sklearn.exceptions import ConvergenceWarning
|
||||
|
||||
from odte import Odte
|
||||
from .utils import load_dataset
|
||||
|
||||
@ -47,3 +50,26 @@ class Odte_test(unittest.TestCase):
|
||||
for value in computed.tolist():
|
||||
self.assertGreaterEqual(value, 101)
|
||||
self.assertLessEqual(value, 1000)
|
||||
|
||||
def test_bogus_n_estimator(self):
|
||||
values = [0, -1]
|
||||
for n_estimators in values:
|
||||
with self.assertRaises(ValueError):
|
||||
tclf = Odte(n_estimators=n_estimators)
|
||||
tclf.fit(*load_dataset(self._random_state))
|
||||
|
||||
def test_predict(self):
|
||||
warnings.filterwarnings("ignore", category=ConvergenceWarning)
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
X, y = load_dataset(self._random_state)
|
||||
expected = np.ones(y.shape[0])
|
||||
tclf = Odte(random_state=self._random_state)
|
||||
computed = tclf.fit(X, y).predict(X)
|
||||
self.assertListEqual(expected.tolist(), computed.tolist())
|
||||
|
||||
def test_score(self):
|
||||
X, y = load_dataset(self._random_state)
|
||||
expected = 0.5
|
||||
tclf = Odte(random_state=self._random_state)
|
||||
computed = tclf.fit(X, y).score(X, y)
|
||||
self.assertAlmostEqual(expected, computed)
|
||||
|
Loading…
x
Reference in New Issue
Block a user