mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-18 17:06:01 +00:00
Implement hyperparam. context based normalization (#32)
This commit is contained in:
committed by
GitHub
parent
b55f59a3ec
commit
8a18c998df
@@ -15,6 +15,7 @@ from typing import Optional
|
||||
import numpy as np
|
||||
from sklearn.base import BaseEstimator, ClassifierMixin
|
||||
from sklearn.svm import SVC, LinearSVC
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.utils import check_consistent_length
|
||||
from sklearn.utils.multiclass import check_classification_targets
|
||||
from sklearn.exceptions import ConvergenceWarning
|
||||
@@ -41,6 +42,7 @@ class Snode:
|
||||
impurity: float,
|
||||
title: str,
|
||||
weight: np.ndarray = None,
|
||||
scaler: StandardScaler = None,
|
||||
):
|
||||
self._clf = clf
|
||||
self._title = title
|
||||
@@ -58,6 +60,7 @@ class Snode:
|
||||
self._features = features
|
||||
self._impurity = impurity
|
||||
self._partition_column: int = -1
|
||||
self._scaler = scaler
|
||||
|
||||
@classmethod
|
||||
def copy(cls, node: "Snode") -> "Snode":
|
||||
@@ -68,6 +71,8 @@ class Snode:
|
||||
node._features,
|
||||
node._impurity,
|
||||
node._title,
|
||||
node._sample_weight,
|
||||
node._scaler,
|
||||
)
|
||||
|
||||
def set_partition_column(self, col: int):
|
||||
@@ -178,6 +183,7 @@ class Splitter:
|
||||
criteria: str = None,
|
||||
min_samples_split: int = None,
|
||||
random_state=None,
|
||||
normalize=False,
|
||||
):
|
||||
self._clf = clf
|
||||
self._random_state = random_state
|
||||
@@ -187,6 +193,7 @@ class Splitter:
|
||||
self._min_samples_split = min_samples_split
|
||||
self._criteria = criteria
|
||||
self._splitter_type = splitter_type
|
||||
self._normalize = normalize
|
||||
|
||||
if clf is None:
|
||||
raise ValueError(f"clf has to be a sklearn estimator, got({clf})")
|
||||
@@ -486,8 +493,7 @@ class Splitter:
|
||||
origin[down] if any(down) else None,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _distances(node: Snode, data: np.ndarray) -> np.array:
|
||||
def _distances(self, node: Snode, data: np.ndarray) -> np.array:
|
||||
"""Compute distances of the samples to the hyperplane of the node
|
||||
|
||||
Parameters
|
||||
@@ -503,7 +509,10 @@ class Splitter:
|
||||
array of shape (m, nc) with the distances of every sample to
|
||||
the hyperplane of every class. nc = # of classes
|
||||
"""
|
||||
return node._clf.decision_function(data[:, node._features])
|
||||
X_transformed = data[:, node._features]
|
||||
if self._normalize:
|
||||
X_transformed = node._scaler.transform(X_transformed)
|
||||
return node._clf.decision_function(X_transformed)
|
||||
|
||||
|
||||
class Stree(BaseEstimator, ClassifierMixin):
|
||||
@@ -529,6 +538,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
min_samples_split: int = 0,
|
||||
max_features=None,
|
||||
splitter: str = "random",
|
||||
normalize: bool = False,
|
||||
):
|
||||
self.max_iter = max_iter
|
||||
self.C = C
|
||||
@@ -543,6 +553,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
self.max_features = max_features
|
||||
self.criterion = criterion
|
||||
self.splitter = splitter
|
||||
self.normalize = normalize
|
||||
|
||||
def _more_tags(self) -> dict:
|
||||
"""Required by sklearn to supply features of the classifier
|
||||
@@ -606,6 +617,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
criteria=self.split_criteria,
|
||||
random_state=self.random_state,
|
||||
min_samples_split=self.min_samples_split,
|
||||
normalize=self.normalize,
|
||||
)
|
||||
if self.random_state is not None:
|
||||
random.seed(self.random_state)
|
||||
@@ -660,7 +672,8 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
y = y[~indices_zero]
|
||||
sample_weight = sample_weight[~indices_zero]
|
||||
self.depth_ = max(depth, self.depth_)
|
||||
node = Snode(None, X, y, X.shape[1], 0.0, title, sample_weight)
|
||||
scaler = StandardScaler()
|
||||
node = Snode(None, X, y, X.shape[1], 0.0, title, sample_weight, scaler)
|
||||
if np.unique(y).shape[0] == 1:
|
||||
# only 1 class => pure dataset
|
||||
node.set_title(title + ", <pure>")
|
||||
@@ -668,6 +681,9 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
# Train the model
|
||||
clf = self._build_clf()
|
||||
Xs, features = self.splitter_.get_subspace(X, y, self.max_features_)
|
||||
if self.normalize:
|
||||
scaler.fit(Xs)
|
||||
Xs = scaler.transform(Xs)
|
||||
clf.fit(Xs, y, sample_weight=sample_weight)
|
||||
node.set_impurity(self.splitter_.partition_impurity(y))
|
||||
node.set_classifier(clf)
|
||||
|
Reference in New Issue
Block a user