Refactor score method using base class implementation

This commit is contained in:
2021-04-19 13:52:36 +02:00
parent 045e2fd446
commit fec094a75f

View File

@@ -16,7 +16,6 @@ import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.svm import SVC, LinearSVC from sklearn.svm import SVC, LinearSVC
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler
from sklearn.utils import check_consistent_length
from sklearn.utils.multiclass import check_classification_targets from sklearn.utils.multiclass import check_classification_targets
from sklearn.exceptions import ConvergenceWarning from sklearn.exceptions import ConvergenceWarning
from sklearn.utils.validation import ( from sklearn.utils.validation import (
@@ -25,7 +24,6 @@ from sklearn.utils.validation import (
check_is_fitted, check_is_fitted,
_check_sample_weight, _check_sample_weight,
) )
from sklearn.metrics._classification import _weighted_sum, _check_targets
class Snode: class Snode:
@@ -832,36 +830,6 @@ class Stree(BaseEstimator, ClassifierMixin):
) )
return self.classes_[result] return self.classes_[result]
def score(
self, X: np.array, y: np.array, sample_weight: np.array = None
) -> float:
"""Compute accuracy of the prediction
Parameters
----------
X : np.array
dataset of samples to make predictions
y : np.array
samples labels
sample_weight : np.array, optional
weights of the samples. Rescale C per sample, by default None
Returns
-------
float
accuracy of the prediction
"""
# sklearn check
check_is_fitted(self)
check_classification_targets(y)
X, y = check_X_y(X, y)
y_pred = self.predict(X).reshape(y.shape)
# Compute accuracy for each possible representation
_, y_true, y_pred = _check_targets(y, y_pred)
check_consistent_length(y_true, y_pred, sample_weight)
score = y_true == y_pred
return _weighted_sum(score, sample_weight, normalize=True)
def nodes_leaves(self) -> tuple: def nodes_leaves(self) -> tuple:
"""Compute the number of nodes and leaves in the built tree """Compute the number of nodes and leaves in the built tree