mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-16 07:56:06 +00:00
Update Readme
Add max_features > n_features test Add make doc
This commit is contained in:
@@ -653,12 +653,12 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
self.n_features_ = X.shape[1]
|
||||
self.n_features_in_ = X.shape[1]
|
||||
self.max_features_ = self._initialize_max_features()
|
||||
self.tree_ = self.train(X, y, sample_weight, 1, "root")
|
||||
self.tree_ = self._train(X, y, sample_weight, 1, "root")
|
||||
self.X_ = X
|
||||
self.y_ = y
|
||||
return self
|
||||
|
||||
def train(
|
||||
def _train(
|
||||
self,
|
||||
X: np.ndarray,
|
||||
y: np.ndarray,
|
||||
@@ -723,10 +723,10 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
node.make_predictor()
|
||||
return node
|
||||
node.set_up(
|
||||
self.train(X_U, y_u, sw_u, depth + 1, title + f" - Up({depth+1})")
|
||||
self._train(X_U, y_u, sw_u, depth + 1, title + f" - Up({depth+1})")
|
||||
)
|
||||
node.set_down(
|
||||
self.train(
|
||||
self._train(
|
||||
X_D, y_d, sw_d, depth + 1, title + f" - Down({depth+1})"
|
||||
)
|
||||
)
|
||||
@@ -892,6 +892,12 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
elif self.max_features is None:
|
||||
max_features = self.n_features_
|
||||
elif isinstance(self.max_features, numbers.Integral):
|
||||
if self.max_features > self.n_features_:
|
||||
raise ValueError(
|
||||
"Invalid value for max_features. "
|
||||
"It can not be greater than number of features "
|
||||
f"({self.n_features_})"
|
||||
)
|
||||
max_features = self.max_features
|
||||
else: # float
|
||||
if self.max_features > 0.0:
|
||||
|
@@ -6,6 +6,5 @@ __author__ = "Ricardo Montañana Gómez"
|
||||
__copyright__ = "Copyright 2020-2021, Ricardo Montañana Gómez"
|
||||
__license__ = "MIT License"
|
||||
__author_email__ = "ricardo.montanana@alu.uclm.es"
|
||||
__url__ = "https://github.com/doctorado-ml/stree"
|
||||
|
||||
__all__ = ["Stree", "Snode", "Siterator", "Splitter"]
|
||||
|
@@ -269,6 +269,12 @@ class Stree_test(unittest.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
_ = clf._initialize_max_features()
|
||||
|
||||
def test_wrong_max_features(self):
|
||||
X, y = load_dataset(n_features=15)
|
||||
clf = Stree(max_features=16)
|
||||
with self.assertRaises(ValueError):
|
||||
clf.fit(X, y)
|
||||
|
||||
def test_get_subspaces(self):
|
||||
dataset = np.random.random((10, 16))
|
||||
y = np.random.randint(0, 2, 10)
|
||||
|
Reference in New Issue
Block a user