From ef6b5b08d5b08adb785aef7c898d5af100f75556 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Sat, 13 Jun 2020 01:18:51 +0200 Subject: [PATCH] Predict and score structure --- .gitignore | 4 ++++ README.md | 6 ++++++ odte/Odte.py | 8 +++++--- odte/tests/Odte_tests.py | 26 ++++++++++++++++++++++++++ 4 files changed, 41 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index b6e4761..d50268a 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,7 @@ dmypy.json # Pyre type checker .pyre/ + +.idea +.vscode +.pre-commit-config.yaml \ No newline at end of file diff --git a/README.md b/README.md index 1583470..f776610 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,8 @@ +[![Codeship Status for Doctorado-ML/Odte](https://app.codeship.com/projects/c279cef0-8f1b-0138-f3d2-5e67174268f2/status?branch=master)](https://app.codeship.com/projects/399830) + +[![codecov](https://codecov.io/gh/Doctorado-ML/odte/branch/master/graph/badge.svg)](https://codecov.io/gh/Doctorado-ML/odte) + +[![Codacy Badge](https://app.codacy.com/project/badge/Grade/c85f935ac6a0482ab67d3ebed4611459)](https://www.codacy.com/gh/Doctorado-ML/Odte?utm_source=github.com&utm_medium=referral&utm_content=Doctorado-ML/Odte&utm_campaign=Badge_Grade) + # odte Oblique Decision Tree Ensemble diff --git a/odte/Odte.py b/odte/Odte.py index e1cbfa0..aab1f70 100644 --- a/odte/Odte.py +++ b/odte/Odte.py @@ -91,6 +91,7 @@ class Odte(BaseEstimator, ClassifierMixin): self.n_classes_ = self.classes_.shape[0] self.estimators_ = [] self._train(X, y, sample_weight) + return self def _train( self, X: np.array, y: np.array, sample_weight: np.array @@ -109,7 +110,7 @@ class Odte(BaseEstimator, ClassifierMixin): weights_update = np.bincount(indices, minlength=n_samples) current_weights = weights * weights_update # train the classifier - clf.fit(X[indices, :], y[indices, :], current_weights[indices, :]) + clf.fit(X[indices, :], y[indices], current_weights[indices]) def _get_bootstrap_n_samples(self, n_samples) -> int: if self.max_samples is None: @@ -131,14 +132,15 @@ class Odte(BaseEstimator, ClassifierMixin): {type(self.max_samples)}" ) - def predict(self, X: np.array): + def predict(self, X: np.array) -> np.array: # todo check_is_fitted(self, ["estimators_"]) # Input validation X = check_array(X) + return np.ones((X.shape[0]),) def score( - self, X: np.array, y: np.array, sample_weight: np.array + self, X: np.array, y: np.array, sample_weight: np.array = None ) -> float: # todo check_is_fitted(self, ["estimators_"]) diff --git a/odte/tests/Odte_tests.py b/odte/tests/Odte_tests.py index 9f4fdee..4372c0e 100644 --- a/odte/tests/Odte_tests.py +++ b/odte/tests/Odte_tests.py @@ -1,6 +1,9 @@ import unittest import numpy as np +import warnings +from sklearn.exceptions import ConvergenceWarning + from odte import Odte from .utils import load_dataset @@ -47,3 +50,26 @@ class Odte_test(unittest.TestCase): for value in computed.tolist(): self.assertGreaterEqual(value, 101) self.assertLessEqual(value, 1000) + + def test_bogus_n_estimator(self): + values = [0, -1] + for n_estimators in values: + with self.assertRaises(ValueError): + tclf = Odte(n_estimators=n_estimators) + tclf.fit(*load_dataset(self._random_state)) + + def test_predict(self): + warnings.filterwarnings("ignore", category=ConvergenceWarning) + warnings.filterwarnings("ignore", category=RuntimeWarning) + X, y = load_dataset(self._random_state) + expected = np.ones(y.shape[0]) + tclf = Odte(random_state=self._random_state) + computed = tclf.fit(X, y).predict(X) + self.assertListEqual(expected.tolist(), computed.tolist()) + + def test_score(self): + X, y = load_dataset(self._random_state) + expected = 0.5 + tclf = Odte(random_state=self._random_state) + computed = tclf.fit(X, y).score(X, y) + self.assertAlmostEqual(expected, computed)