Better solution to the sklearn bagging problem

Add better tests
Enhance .coveragerc
This commit is contained in:
2020-06-26 11:22:45 +02:00
parent 76723993fd
commit 4b7e4a3fb0
5 changed files with 66 additions and 58 deletions

View File

@@ -40,6 +40,7 @@ class Snode:
features: np.array,
impurity: float,
title: str,
weight: np.ndarray = None,
):
self._clf = clf
self._title = title
@@ -51,7 +52,9 @@ class Snode:
self._up = None
self._class = None
self._feature = None
self._sample_weight = None
self._sample_weight = (
weight if os.environ.get("TESTING", "NS") != "NS" else None
)
self._features = features
self._impurity = impurity
@@ -443,9 +446,6 @@ class Stree(BaseEstimator, ClassifierMixin):
sample_weight = _check_sample_weight(
sample_weight, X, dtype=np.float64
)
# solve WARNING: class label 0 specified in weight is not found
# in bagging
sample_weight += 1e-5
check_classification_targets(y)
# Initialize computed parameters
self.splitter_ = Splitter(
@@ -505,13 +505,22 @@ class Stree(BaseEstimator, ClassifierMixin):
features=X.shape[1],
impurity=0.0,
title=title + ", <pure>",
weight=sample_weight,
)
# Train the model
clf = self._build_clf()
Xs, features = self.splitter_.get_subspace(X, y, self.max_features_)
# Work around the warning "class label 0 specified in weight is not
# found", seen when bagging assigns zero weight to all samples of a class
if any(sample_weight == 0):
indices = sample_weight == 0
y_next = y[~indices]
# nudge the weights only if ignoring zero-weight samples drops a class
if np.unique(y_next).shape[0] != self.n_classes_:
sample_weight += 1e-5
clf.fit(Xs, y, sample_weight=sample_weight)
impurity = self.splitter_.impurity(y)
node = Snode(clf, X, y, features, impurity, title)
node = Snode(clf, X, y, features, impurity, title, sample_weight)
self.depth_ = max(depth, self.depth_)
self.splitter_.partition(X, node)
X_U, X_D = self.splitter_.part(X)
@@ -526,6 +535,7 @@ class Stree(BaseEstimator, ClassifierMixin):
features=X.shape[1],
impurity=impurity,
title=title + ", <cgaf>",
weight=sample_weight,
)
node.set_up(self.train(X_U, y_u, sw_u, depth + 1, title + " - Up"))
node.set_down(self.train(X_D, y_d, sw_d, depth + 1, title + " - Down"))