mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-16 16:06:01 +00:00
Get only 3 sets for best split
Fix flaky test in Splitter_test
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -131,3 +131,5 @@ dmypy.json
|
|||||||
.idea
|
.idea
|
||||||
.vscode
|
.vscode
|
||||||
.pre-commit-config.yaml
|
.pre-commit-config.yaml
|
||||||
|
|
||||||
|
**.csv
|
File diff suppressed because one or more lines are too long
@@ -218,7 +218,7 @@ class Splitter:
|
|||||||
imp_dn = self.criterion_function(labels_dn)
|
imp_dn = self.criterion_function(labels_dn)
|
||||||
samples = card_up + card_dn
|
samples = card_up + card_dn
|
||||||
if samples == 0:
|
if samples == 0:
|
||||||
return 0
|
return 0.0
|
||||||
else:
|
else:
|
||||||
result = (
|
result = (
|
||||||
imp_prev
|
imp_prev
|
||||||
@@ -244,7 +244,6 @@ class Splitter:
|
|||||||
if gain > max_gain:
|
if gain > max_gain:
|
||||||
max_gain = gain
|
max_gain = gain
|
||||||
selected = feature_set
|
selected = feature_set
|
||||||
|
|
||||||
return selected if selected is not None else feature_set
|
return selected if selected is not None else feature_set
|
||||||
|
|
||||||
def _get_subspaces_set(
|
def _get_subspaces_set(
|
||||||
@@ -257,6 +256,9 @@ class Splitter:
|
|||||||
index = random.randint(0, len(features_sets) - 1)
|
index = random.randint(0, len(features_sets) - 1)
|
||||||
return features_sets[index]
|
return features_sets[index]
|
||||||
else:
|
else:
|
||||||
|
# get only 3 sets at most
|
||||||
|
if len(features_sets) > 3:
|
||||||
|
features_sets = random.sample(features_sets, 3)
|
||||||
return self._select_best_set(dataset, labels, features_sets)
|
return self._select_best_set(dataset, labels, features_sets)
|
||||||
else:
|
else:
|
||||||
return features_sets[0]
|
return features_sets[0]
|
||||||
|
@@ -1,11 +1,11 @@
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
import random
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn.svm import LinearSVC
|
from sklearn.svm import SVC
|
||||||
|
from sklearn.datasets import load_wine
|
||||||
from stree import Splitter
|
from stree import Splitter
|
||||||
from .utils import load_dataset
|
|
||||||
|
|
||||||
|
|
||||||
class Splitter_test(unittest.TestCase):
|
class Splitter_test(unittest.TestCase):
|
||||||
@@ -15,7 +15,7 @@ class Splitter_test(unittest.TestCase):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def build(
|
def build(
|
||||||
clf=LinearSVC(),
|
clf=SVC,
|
||||||
min_samples_split=0,
|
min_samples_split=0,
|
||||||
splitter_type="random",
|
splitter_type="random",
|
||||||
criterion="gini",
|
criterion="gini",
|
||||||
@@ -23,7 +23,7 @@ class Splitter_test(unittest.TestCase):
|
|||||||
random_state=None,
|
random_state=None,
|
||||||
):
|
):
|
||||||
return Splitter(
|
return Splitter(
|
||||||
clf=clf,
|
clf=clf(random_state=random_state, kernel="rbf"),
|
||||||
min_samples_split=min_samples_split,
|
min_samples_split=min_samples_split,
|
||||||
splitter_type=splitter_type,
|
splitter_type=splitter_type,
|
||||||
criterion=criterion,
|
criterion=criterion,
|
||||||
@@ -43,7 +43,7 @@ class Splitter_test(unittest.TestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
self.build(criteria="duck")
|
self.build(criteria="duck")
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
self.build(clf=None)
|
_ = Splitter(clf=None)
|
||||||
for splitter_type in ["best", "random"]:
|
for splitter_type in ["best", "random"]:
|
||||||
for criterion in ["gini", "entropy"]:
|
for criterion in ["gini", "entropy"]:
|
||||||
for criteria in [
|
for criteria in [
|
||||||
@@ -178,26 +178,23 @@ class Splitter_test(unittest.TestCase):
|
|||||||
|
|
||||||
def test_splitter_parameter(self):
|
def test_splitter_parameter(self):
|
||||||
expected_values = [
|
expected_values = [
|
||||||
[1, 2], # random gini min_distance
|
[2, 3, 5, 7], # best entropy min_distance
|
||||||
[0, 2], # random gini max_samples
|
[0, 2, 4, 5], # best entropy max_samples
|
||||||
[1, 3], # random gini max_distance
|
[0, 2, 8, 12], # best entropy max_distance
|
||||||
[1, 2], # random entropy min_distance
|
[1, 2, 5, 12], # best gini min_distance
|
||||||
[1, 2], # random entropy max_samples
|
[0, 3, 4, 10], # best gini max_samples
|
||||||
[0, 2], # random entropy max_distance
|
[1, 2, 9, 12], # best gini max_distance
|
||||||
[1, 2], # best gini min_distance
|
[3, 9, 11, 12], # random entropy min_distance
|
||||||
[0, 2], # best gini max_samples
|
[1, 5, 6, 9], # random entropy max_samples
|
||||||
[0, 2], # best gini max_distance
|
[1, 2, 4, 8], # random entropy max_distance
|
||||||
[0, 1], # best entropy min_distance
|
[2, 6, 7, 12], # random gini min_distance
|
||||||
[0, 1], # best entropy max_samples
|
[3, 9, 10, 11], # random gini max_samples
|
||||||
[0, 1], # best entropy max_distance
|
[2, 5, 8, 12], # random gini max_distance
|
||||||
]
|
]
|
||||||
X, y = load_dataset(self._random_state, n_features=6, n_classes=3)
|
X, y = load_wine(return_X_y=True)
|
||||||
from sklearn.datasets import load_iris
|
|
||||||
|
|
||||||
X, y = load_iris(return_X_y=True)
|
|
||||||
rn = 0
|
rn = 0
|
||||||
for splitter_type in ["random", "best"]:
|
for splitter_type in ["best", "random"]:
|
||||||
for criterion in ["gini", "entropy"]:
|
for criterion in ["entropy", "gini"]:
|
||||||
for criteria in [
|
for criteria in [
|
||||||
"min_distance",
|
"min_distance",
|
||||||
"max_samples",
|
"max_samples",
|
||||||
@@ -207,11 +204,11 @@ class Splitter_test(unittest.TestCase):
|
|||||||
splitter_type=splitter_type,
|
splitter_type=splitter_type,
|
||||||
criterion=criterion,
|
criterion=criterion,
|
||||||
criteria=criteria,
|
criteria=criteria,
|
||||||
random_state=rn,
|
|
||||||
)
|
)
|
||||||
rn += 3
|
|
||||||
expected = expected_values.pop(0)
|
expected = expected_values.pop(0)
|
||||||
dataset, computed = tcl.get_subspace(X, y, max_features=2)
|
random.seed(rn)
|
||||||
|
rn += 1
|
||||||
|
dataset, computed = tcl.get_subspace(X, y, max_features=4)
|
||||||
# print(
|
# print(
|
||||||
# "{}, # {:7s}{:8s}{:15s}".format(
|
# "{}, # {:7s}{:8s}{:15s}".format(
|
||||||
# list(computed), splitter_type, criterion,
|
# list(computed), splitter_type, criterion,
|
||||||
|
Reference in New Issue
Block a user