mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-16 16:06:01 +00:00
Refactor score method using base class implementation
This commit is contained in:
@@ -16,7 +16,6 @@ import numpy as np
|
|||||||
from sklearn.base import BaseEstimator, ClassifierMixin
|
from sklearn.base import BaseEstimator, ClassifierMixin
|
||||||
from sklearn.svm import SVC, LinearSVC
|
from sklearn.svm import SVC, LinearSVC
|
||||||
from sklearn.preprocessing import StandardScaler
|
from sklearn.preprocessing import StandardScaler
|
||||||
from sklearn.utils import check_consistent_length
|
|
||||||
from sklearn.utils.multiclass import check_classification_targets
|
from sklearn.utils.multiclass import check_classification_targets
|
||||||
from sklearn.exceptions import ConvergenceWarning
|
from sklearn.exceptions import ConvergenceWarning
|
||||||
from sklearn.utils.validation import (
|
from sklearn.utils.validation import (
|
||||||
@@ -25,7 +24,6 @@ from sklearn.utils.validation import (
|
|||||||
check_is_fitted,
|
check_is_fitted,
|
||||||
_check_sample_weight,
|
_check_sample_weight,
|
||||||
)
|
)
|
||||||
from sklearn.metrics._classification import _weighted_sum, _check_targets
|
|
||||||
|
|
||||||
|
|
||||||
class Snode:
|
class Snode:
|
||||||
@@ -832,36 +830,6 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
)
|
)
|
||||||
return self.classes_[result]
|
return self.classes_[result]
|
||||||
|
|
||||||
def score(
|
|
||||||
self, X: np.array, y: np.array, sample_weight: np.array = None
|
|
||||||
) -> float:
|
|
||||||
"""Compute accuracy of the prediction
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
X : np.array
|
|
||||||
dataset of samples to make predictions
|
|
||||||
y : np.array
|
|
||||||
samples labels
|
|
||||||
sample_weight : np.array, optional
|
|
||||||
weights of the samples. Rescale C per sample, by default None
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
float
|
|
||||||
accuracy of the prediction
|
|
||||||
"""
|
|
||||||
# sklearn check
|
|
||||||
check_is_fitted(self)
|
|
||||||
check_classification_targets(y)
|
|
||||||
X, y = check_X_y(X, y)
|
|
||||||
y_pred = self.predict(X).reshape(y.shape)
|
|
||||||
# Compute accuracy for each possible representation
|
|
||||||
_, y_true, y_pred = _check_targets(y, y_pred)
|
|
||||||
check_consistent_length(y_true, y_pred, sample_weight)
|
|
||||||
score = y_true == y_pred
|
|
||||||
return _weighted_sum(score, sample_weight, normalize=True)
|
|
||||||
|
|
||||||
def nodes_leaves(self) -> tuple:
|
def nodes_leaves(self) -> tuple:
|
||||||
"""Compute the number of nodes and leaves in the built tree
|
"""Compute the number of nodes and leaves in the built tree
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user